mxnet
op.h
Go to the documentation of this file.
1 
8 #ifndef MXNET_CPP_OP_H_
9 #define MXNET_CPP_OP_H_
10 
11 #include <string>
12 #include <vector>
13 #include "mxnet-cpp/base.h"
14 #include "mxnet-cpp/shape.h"
15 #include "mxnet-cpp/op_util.h"
16 #include "mxnet-cpp/operator.h"
17 #include "dmlc/optional.h"
18 
19 namespace mxnet {
20 namespace cpp {
21 
62 inline Symbol khatri_rao(const std::string& symbol_name,
63  const std::vector<Symbol>& args) {
64  return Operator("khatri_rao")
65 (args)
66  .CreateSymbol(symbol_name);
67 }
68 
84 inline Symbol Custom(const std::string& symbol_name,
85  const std::vector<Symbol>& data,
86  const std::string& op_type) {
87  return Operator("Custom")
88 (data)
89  .CreateSymbol(symbol_name);
90 }
91 
114 inline Symbol broadcast_power(const std::string& symbol_name,
115  Symbol lhs,
116  Symbol rhs) {
117  return Operator("broadcast_power")
118  .SetInput("lhs", lhs)
119  .SetInput("rhs", rhs)
120  .CreateSymbol(symbol_name);
121 }
122 
147 inline Symbol broadcast_maximum(const std::string& symbol_name,
148  Symbol lhs,
149  Symbol rhs) {
150  return Operator("broadcast_maximum")
151  .SetInput("lhs", lhs)
152  .SetInput("rhs", rhs)
153  .CreateSymbol(symbol_name);
154 }
155 
180 inline Symbol broadcast_minimum(const std::string& symbol_name,
181  Symbol lhs,
182  Symbol rhs) {
183  return Operator("broadcast_minimum")
184  .SetInput("lhs", lhs)
185  .SetInput("rhs", rhs)
186  .CreateSymbol(symbol_name);
187 }
188 
219 inline Symbol broadcast_hypot(const std::string& symbol_name,
220  Symbol lhs,
221  Symbol rhs) {
222  return Operator("broadcast_hypot")
223  .SetInput("lhs", lhs)
224  .SetInput("rhs", rhs)
225  .CreateSymbol(symbol_name);
226 }
227 
302 inline Symbol Reshape(const std::string& symbol_name,
303  Symbol data,
304  Shape shape = Shape(),
305  bool reverse = false,
306  Shape target_shape = Shape(),
307  bool keep_highest = false) {
308  return Operator("Reshape")
309  .SetParam("shape", shape)
310  .SetParam("reverse", reverse)
311  .SetParam("target_shape", target_shape)
312  .SetParam("keep_highest", keep_highest)
313  .SetInput("data", data)
314  .CreateSymbol(symbol_name);
315 }
316 
347 inline Symbol Flatten(const std::string& symbol_name,
348  Symbol data) {
349  return Operator("Flatten")
350  .SetInput("data", data)
351  .CreateSymbol(symbol_name);
352 }
353 
390 inline Symbol transpose(const std::string& symbol_name,
391  Symbol data,
392  Shape axes = Shape()) {
393  return Operator("transpose")
394  .SetParam("axes", axes)
395  .SetInput("data", data)
396  .CreateSymbol(symbol_name);
397 }
398 
414 inline Symbol expand_dims(const std::string& symbol_name,
415  Symbol data,
416  int axis) {
417  return Operator("expand_dims")
418  .SetParam("axis", axis)
419  .SetInput("data", data)
420  .CreateSymbol(symbol_name);
421 }
422 
478 inline Symbol slice(const std::string& symbol_name,
479  Symbol data,
480  Shape begin,
481  Shape end,
482  Shape step = Shape()) {
483  return Operator("slice")
484  .SetParam("begin", begin)
485  .SetParam("end", end)
486  .SetParam("step", step)
487  .SetInput("data", data)
488  .CreateSymbol(symbol_name);
489 }
490 
523 inline Symbol slice_axis(const std::string& symbol_name,
524  Symbol data,
525  int axis,
526  int begin,
527  dmlc::optional<int> end) {
528  return Operator("slice_axis")
529  .SetParam("axis", axis)
530  .SetParam("begin", begin)
531  .SetParam("end", end)
532  .SetInput("data", data)
533  .CreateSymbol(symbol_name);
534 }
535 
570 inline Symbol clip(const std::string& symbol_name,
571  Symbol data,
572  mx_float a_min,
573  mx_float a_max) {
574  return Operator("clip")
575  .SetParam("a_min", a_min)
576  .SetParam("a_max", a_max)
577  .SetInput("data", data)
578  .CreateSymbol(symbol_name);
579 }
580 
615 inline Symbol repeat(const std::string& symbol_name,
616  Symbol data,
617  int repeats,
618  dmlc::optional<int> axis = dmlc::optional<int>()) {
619  return Operator("repeat")
620  .SetParam("repeats", repeats)
621  .SetParam("axis", axis)
622  .SetInput("data", data)
623  .CreateSymbol(symbol_name);
624 }
625 
671 inline Symbol tile(const std::string& symbol_name,
672  Symbol data,
673  Shape reps) {
674  return Operator("tile")
675  .SetParam("reps", reps)
676  .SetInput("data", data)
677  .CreateSymbol(symbol_name);
678 }
679 
703 inline Symbol reverse(const std::string& symbol_name,
704  Symbol data,
705  Shape axis) {
706  return Operator("reverse")
707  .SetParam("axis", axis)
708  .SetInput("data", data)
709  .CreateSymbol(symbol_name);
710 }
711 
735 inline Symbol stack(const std::string& symbol_name,
736  const std::vector<Symbol>& data,
737  int num_args,
738  int axis = 0) {
739  return Operator("stack")
740  .SetParam("num_args", num_args)
741  .SetParam("axis", axis)
742 (data)
743  .CreateSymbol(symbol_name);
744 }
745 
769 inline Symbol zeros_like(const std::string& symbol_name,
770  Symbol data) {
771  return Operator("zeros_like")
772  .SetInput("data", data)
773  .CreateSymbol(symbol_name);
774 }
775 
793 inline Symbol ones_like(const std::string& symbol_name,
794  Symbol data) {
795  return Operator("ones_like")
796  .SetInput("data", data)
797  .CreateSymbol(symbol_name);
798 }
799 
827 inline Symbol broadcast_add(const std::string& symbol_name,
828  Symbol lhs,
829  Symbol rhs) {
830  return Operator("broadcast_add")
831  .SetInput("lhs", lhs)
832  .SetInput("rhs", rhs)
833  .CreateSymbol(symbol_name);
834 }
835 
863 inline Symbol broadcast_sub(const std::string& symbol_name,
864  Symbol lhs,
865  Symbol rhs) {
866  return Operator("broadcast_sub")
867  .SetInput("lhs", lhs)
868  .SetInput("rhs", rhs)
869  .CreateSymbol(symbol_name);
870 }
871 
894 inline Symbol broadcast_mul(const std::string& symbol_name,
895  Symbol lhs,
896  Symbol rhs) {
897  return Operator("broadcast_mul")
898  .SetInput("lhs", lhs)
899  .SetInput("rhs", rhs)
900  .CreateSymbol(symbol_name);
901 }
902 
925 inline Symbol broadcast_div(const std::string& symbol_name,
926  Symbol lhs,
927  Symbol rhs) {
928  return Operator("broadcast_div")
929  .SetInput("lhs", lhs)
930  .SetInput("rhs", rhs)
931  .CreateSymbol(symbol_name);
932 }
933 
956 inline Symbol broadcast_mod(const std::string& symbol_name,
957  Symbol lhs,
958  Symbol rhs) {
959  return Operator("broadcast_mod")
960  .SetInput("lhs", lhs)
961  .SetInput("rhs", rhs)
962  .CreateSymbol(symbol_name);
963 }
964 
985 inline Symbol add_n(const std::string& symbol_name,
986  const std::vector<Symbol>& args) {
987  return Operator("add_n")
988 (args)
989  .CreateSymbol(symbol_name);
990 }
991 
1023 inline Symbol argmax(const std::string& symbol_name,
1024  Symbol data,
1025  dmlc::optional<int> axis = dmlc::optional<int>(),
1026  bool keepdims = false) {
1027  return Operator("argmax")
1028  .SetParam("axis", axis)
1029  .SetParam("keepdims", keepdims)
1030  .SetInput("data", data)
1031  .CreateSymbol(symbol_name);
1032 }
1033 
1065 inline Symbol argmin(const std::string& symbol_name,
1066  Symbol data,
1067  dmlc::optional<int> axis = dmlc::optional<int>(),
1068  bool keepdims = false) {
1069  return Operator("argmin")
1070  .SetParam("axis", axis)
1071  .SetParam("keepdims", keepdims)
1072  .SetInput("data", data)
1073  .CreateSymbol(symbol_name);
1074 }
1075 
1098 inline Symbol argmax_channel(const std::string& symbol_name,
1099  Symbol data) {
1100  return Operator("argmax_channel")
1101  .SetInput("data", data)
1102  .CreateSymbol(symbol_name);
1103 }
1104 
1150 inline Symbol pick(const std::string& symbol_name,
1151  Symbol data,
1152  Symbol index,
1153  dmlc::optional<int> axis = dmlc::optional<int>(),
1154  bool keepdims = false) {
1155  return Operator("pick")
1156  .SetParam("axis", axis)
1157  .SetParam("keepdims", keepdims)
1158  .SetInput("data", data)
1159  .SetInput("index", index)
1160  .CreateSymbol(symbol_name);
1161 }
1162 
1203 inline Symbol dot(const std::string& symbol_name,
1204  Symbol lhs,
1205  Symbol rhs,
1206  bool transpose_a = false,
1207  bool transpose_b = false) {
1208  return Operator("dot")
1209  .SetParam("transpose_a", transpose_a)
1210  .SetParam("transpose_b", transpose_b)
1211  .SetInput("lhs", lhs)
1212  .SetInput("rhs", rhs)
1213  .CreateSymbol(symbol_name);
1214 }
1215 
1238 inline Symbol batch_dot(const std::string& symbol_name,
1239  Symbol lhs,
1240  Symbol rhs,
1241  bool transpose_a = false,
1242  bool transpose_b = false) {
1243  return Operator("batch_dot")
1244  .SetParam("transpose_a", transpose_a)
1245  .SetParam("transpose_b", transpose_b)
1246  .SetInput("lhs", lhs)
1247  .SetInput("rhs", rhs)
1248  .CreateSymbol(symbol_name);
1249 }
1250 
1269 inline Symbol relu(const std::string& symbol_name,
1270  Symbol data) {
1271  return Operator("relu")
1272  .SetInput("data", data)
1273  .CreateSymbol(symbol_name);
1274 }
1275 
1291 inline Symbol sigmoid(const std::string& symbol_name,
1292  Symbol data) {
1293  return Operator("sigmoid")
1294  .SetInput("data", data)
1295  .CreateSymbol(symbol_name);
1296 }
1297 
1331 inline Symbol BlockGrad(const std::string& symbol_name,
1332  Symbol data) {
1333  return Operator("BlockGrad")
1334  .SetInput("data", data)
1335  .CreateSymbol(symbol_name);
1336 }
1337 
1367 inline Symbol make_loss(const std::string& symbol_name,
1368  Symbol data) {
1369  return Operator("make_loss")
1370  .SetInput("data", data)
1371  .CreateSymbol(symbol_name);
1372 }
1373 
1381 inline Symbol reshape_like(const std::string& symbol_name,
1382  Symbol lhs,
1383  Symbol rhs) {
1384  return Operator("reshape_like")
1385  .SetInput("lhs", lhs)
1386  .SetInput("rhs", rhs)
1387  .CreateSymbol(symbol_name);
1388 }
1389 
/// \brief dtype choices for `Cast`.
/// Each enumerator's numeric value indexes the dtype-name table inside
/// `Cast`, so the order and values here must stay in sync with it.
enum class CastDtype : int {
  kFloat16 = 0,
  kFloat32 = 1,
  kFloat64 = 2,
  kInt32 = 3,
  kUint8 = 4
};
1399 
1419 inline Symbol Cast(const std::string& symbol_name,
1420  Symbol data,
1421  CastDtype dtype) {
1422  static const char *CastDtypeValues[] = {
1423  "float16",
1424  "float32",
1425  "float64",
1426  "int32",
1427  "uint8"
1428  };
1429  return Operator("Cast")
1430  .SetParam("dtype", CastDtypeValues[int(dtype)])
1431  .SetInput("data", data)
1432  .CreateSymbol(symbol_name);
1433 }
1434 
1449 inline Symbol negative(const std::string& symbol_name,
1450  Symbol data) {
1451  return Operator("negative")
1452  .SetInput("data", data)
1453  .CreateSymbol(symbol_name);
1454 }
1455 
1472 inline Symbol reciprocal(const std::string& symbol_name,
1473  Symbol data) {
1474  return Operator("reciprocal")
1475  .SetInput("data", data)
1476  .CreateSymbol(symbol_name);
1477 }
1478 
1498 inline Symbol abs(const std::string& symbol_name,
1499  Symbol data) {
1500  return Operator("abs")
1501  .SetInput("data", data)
1502  .CreateSymbol(symbol_name);
1503 }
1504 
1524 inline Symbol sign(const std::string& symbol_name,
1525  Symbol data) {
1526  return Operator("sign")
1527  .SetInput("data", data)
1528  .CreateSymbol(symbol_name);
1529 }
1530 
1550 inline Symbol round(const std::string& symbol_name,
1551  Symbol data) {
1552  return Operator("round")
1553  .SetInput("data", data)
1554  .CreateSymbol(symbol_name);
1555 }
1556 
1580 inline Symbol rint(const std::string& symbol_name,
1581  Symbol data) {
1582  return Operator("rint")
1583  .SetInput("data", data)
1584  .CreateSymbol(symbol_name);
1585 }
1586 
1608 inline Symbol ceil(const std::string& symbol_name,
1609  Symbol data) {
1610  return Operator("ceil")
1611  .SetInput("data", data)
1612  .CreateSymbol(symbol_name);
1613 }
1614 
1636 inline Symbol floor(const std::string& symbol_name,
1637  Symbol data) {
1638  return Operator("floor")
1639  .SetInput("data", data)
1640  .CreateSymbol(symbol_name);
1641 }
1642 
1665 inline Symbol trunc(const std::string& symbol_name,
1666  Symbol data) {
1667  return Operator("trunc")
1668  .SetInput("data", data)
1669  .CreateSymbol(symbol_name);
1670 }
1671 
1692 inline Symbol fix(const std::string& symbol_name,
1693  Symbol data) {
1694  return Operator("fix")
1695  .SetInput("data", data)
1696  .CreateSymbol(symbol_name);
1697 }
1698 
1722 inline Symbol square(const std::string& symbol_name,
1723  Symbol data) {
1724  return Operator("square")
1725  .SetInput("data", data)
1726  .CreateSymbol(symbol_name);
1727 }
1728 
1751 inline Symbol sqrt(const std::string& symbol_name,
1752  Symbol data) {
1753  return Operator("sqrt")
1754  .SetInput("data", data)
1755  .CreateSymbol(symbol_name);
1756 }
1757 
1777 inline Symbol rsqrt(const std::string& symbol_name,
1778  Symbol data) {
1779  return Operator("rsqrt")
1780  .SetInput("data", data)
1781  .CreateSymbol(symbol_name);
1782 }
1783 
1801 inline Symbol cbrt(const std::string& symbol_name,
1802  Symbol data) {
1803  return Operator("cbrt")
1804  .SetInput("data", data)
1805  .CreateSymbol(symbol_name);
1806 }
1807 
1825 inline Symbol rcbrt(const std::string& symbol_name,
1826  Symbol data) {
1827  return Operator("rcbrt")
1828  .SetInput("data", data)
1829  .CreateSymbol(symbol_name);
1830 }
1831 
1851 inline Symbol exp(const std::string& symbol_name,
1852  Symbol data) {
1853  return Operator("exp")
1854  .SetInput("data", data)
1855  .CreateSymbol(symbol_name);
1856 }
1857 
1872 inline Symbol log(const std::string& symbol_name,
1873  Symbol data) {
1874  return Operator("log")
1875  .SetInput("data", data)
1876  .CreateSymbol(symbol_name);
1877 }
1878 
1893 inline Symbol log10(const std::string& symbol_name,
1894  Symbol data) {
1895  return Operator("log10")
1896  .SetInput("data", data)
1897  .CreateSymbol(symbol_name);
1898 }
1899 
1914 inline Symbol log2(const std::string& symbol_name,
1915  Symbol data) {
1916  return Operator("log2")
1917  .SetInput("data", data)
1918  .CreateSymbol(symbol_name);
1919 }
1920 
1939 inline Symbol log1p(const std::string& symbol_name,
1940  Symbol data) {
1941  return Operator("log1p")
1942  .SetInput("data", data)
1943  .CreateSymbol(symbol_name);
1944 }
1945 
1963 inline Symbol expm1(const std::string& symbol_name,
1964  Symbol data) {
1965  return Operator("expm1")
1966  .SetInput("data", data)
1967  .CreateSymbol(symbol_name);
1968 }
1969 
1981 inline Symbol gamma(const std::string& symbol_name,
1982  Symbol data) {
1983  return Operator("gamma")
1984  .SetInput("data", data)
1985  .CreateSymbol(symbol_name);
1986 }
1987 
1999 inline Symbol gammaln(const std::string& symbol_name,
2000  Symbol data) {
2001  return Operator("gammaln")
2002  .SetInput("data", data)
2003  .CreateSymbol(symbol_name);
2004 }
2005 
2064 inline Symbol sum(const std::string& symbol_name,
2065  Symbol data,
2066  Shape axis = Shape(),
2067  bool keepdims = false,
2068  bool exclude = false) {
2069  return Operator("sum")
2070  .SetParam("axis", axis)
2071  .SetParam("keepdims", keepdims)
2072  .SetParam("exclude", exclude)
2073  .SetInput("data", data)
2074  .CreateSymbol(symbol_name);
2075 }
2076 
2101 inline Symbol mean(const std::string& symbol_name,
2102  Symbol data,
2103  Shape axis = Shape(),
2104  bool keepdims = false,
2105  bool exclude = false) {
2106  return Operator("mean")
2107  .SetParam("axis", axis)
2108  .SetParam("keepdims", keepdims)
2109  .SetParam("exclude", exclude)
2110  .SetInput("data", data)
2111  .CreateSymbol(symbol_name);
2112 }
2113 
2138 inline Symbol prod(const std::string& symbol_name,
2139  Symbol data,
2140  Shape axis = Shape(),
2141  bool keepdims = false,
2142  bool exclude = false) {
2143  return Operator("prod")
2144  .SetParam("axis", axis)
2145  .SetParam("keepdims", keepdims)
2146  .SetParam("exclude", exclude)
2147  .SetInput("data", data)
2148  .CreateSymbol(symbol_name);
2149 }
2150 
2177 inline Symbol nansum(const std::string& symbol_name,
2178  Symbol data,
2179  Shape axis = Shape(),
2180  bool keepdims = false,
2181  bool exclude = false) {
2182  return Operator("nansum")
2183  .SetParam("axis", axis)
2184  .SetParam("keepdims", keepdims)
2185  .SetParam("exclude", exclude)
2186  .SetInput("data", data)
2187  .CreateSymbol(symbol_name);
2188 }
2189 
2216 inline Symbol nanprod(const std::string& symbol_name,
2217  Symbol data,
2218  Shape axis = Shape(),
2219  bool keepdims = false,
2220  bool exclude = false) {
2221  return Operator("nanprod")
2222  .SetParam("axis", axis)
2223  .SetParam("keepdims", keepdims)
2224  .SetParam("exclude", exclude)
2225  .SetInput("data", data)
2226  .CreateSymbol(symbol_name);
2227 }
2228 
2253 inline Symbol max(const std::string& symbol_name,
2254  Symbol data,
2255  Shape axis = Shape(),
2256  bool keepdims = false,
2257  bool exclude = false) {
2258  return Operator("max")
2259  .SetParam("axis", axis)
2260  .SetParam("keepdims", keepdims)
2261  .SetParam("exclude", exclude)
2262  .SetInput("data", data)
2263  .CreateSymbol(symbol_name);
2264 }
2265 
2290 inline Symbol min(const std::string& symbol_name,
2291  Symbol data,
2292  Shape axis = Shape(),
2293  bool keepdims = false,
2294  bool exclude = false) {
2295  return Operator("min")
2296  .SetParam("axis", axis)
2297  .SetParam("keepdims", keepdims)
2298  .SetParam("exclude", exclude)
2299  .SetInput("data", data)
2300  .CreateSymbol(symbol_name);
2301 }
2302 
2332 inline Symbol broadcast_axis(const std::string& symbol_name,
2333  Symbol data,
2334  Shape axis = Shape(),
2335  Shape size = Shape()) {
2336  return Operator("broadcast_axis")
2337  .SetParam("axis", axis)
2338  .SetParam("size", size)
2339  .SetInput("data", data)
2340  .CreateSymbol(symbol_name);
2341 }
2342 
2371 inline Symbol broadcast_to(const std::string& symbol_name,
2372  Symbol data,
2373  Shape shape = Shape()) {
2374  return Operator("broadcast_to")
2375  .SetParam("shape", shape)
2376  .SetInput("data", data)
2377  .CreateSymbol(symbol_name);
2378 }
2379 
2405 inline Symbol norm(const std::string& symbol_name,
2406  Symbol data) {
2407  return Operator("norm")
2408  .SetInput("data", data)
2409  .CreateSymbol(symbol_name);
2410 }
2411 
/// \brief Return-type choices for `topk`.
/// Each enumerator's numeric value indexes the name table inside `topk`,
/// so the order and values here must stay in sync with it.
enum class TopkRetTyp : int {
  kBoth = 0,
  kIndices = 1,
  kMask = 2,
  kValue = 3
};
2423 
2465 inline Symbol topk(const std::string& symbol_name,
2466  Symbol data,
2467  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2468  int k = 1,
2469  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
2470  bool is_ascend = false) {
2471  static const char *TopkRetTypValues[] = {
2472  "both",
2473  "indices",
2474  "mask",
2475  "value"
2476  };
2477  return Operator("topk")
2478  .SetParam("axis", axis)
2479  .SetParam("k", k)
2480  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
2481  .SetParam("is_ascend", is_ascend)
2482  .SetInput("data", data)
2483  .CreateSymbol(symbol_name);
2484 }
2485 
2518 inline Symbol sort(const std::string& symbol_name,
2519  Symbol data,
2520  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2521  bool is_ascend = true) {
2522  return Operator("sort")
2523  .SetParam("axis", axis)
2524  .SetParam("is_ascend", is_ascend)
2525  .SetInput("data", data)
2526  .CreateSymbol(symbol_name);
2527 }
2528 
2559 inline Symbol argsort(const std::string& symbol_name,
2560  Symbol data,
2561  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2562  bool is_ascend = true) {
2563  return Operator("argsort")
2564  .SetParam("axis", axis)
2565  .SetParam("is_ascend", is_ascend)
2566  .SetInput("data", data)
2567  .CreateSymbol(symbol_name);
2568 }
2569 
2585 inline Symbol elemwise_add(const std::string& symbol_name,
2586  Symbol lhs,
2587  Symbol rhs) {
2588  return Operator("elemwise_add")
2589  .SetInput("lhs", lhs)
2590  .SetInput("rhs", rhs)
2591  .CreateSymbol(symbol_name);
2592 }
2593 
2609 inline Symbol elemwise_sub(const std::string& symbol_name,
2610  Symbol lhs,
2611  Symbol rhs) {
2612  return Operator("elemwise_sub")
2613  .SetInput("lhs", lhs)
2614  .SetInput("rhs", rhs)
2615  .CreateSymbol(symbol_name);
2616 }
2617 
2636 inline Symbol elemwise_mul(const std::string& symbol_name,
2637  Symbol lhs,
2638  Symbol rhs) {
2639  return Operator("elemwise_mul")
2640  .SetInput("lhs", lhs)
2641  .SetInput("rhs", rhs)
2642  .CreateSymbol(symbol_name);
2643 }
2644 
2656 inline Symbol elemwise_div(const std::string& symbol_name,
2657  Symbol lhs,
2658  Symbol rhs) {
2659  return Operator("elemwise_div")
2660  .SetInput("lhs", lhs)
2661  .SetInput("rhs", rhs)
2662  .CreateSymbol(symbol_name);
2663 }
2664 
/// \brief dtype choices for `Embedding`.
/// Each enumerator's numeric value indexes the dtype-name table inside
/// `Embedding`, so the order and values here must stay in sync with it.
enum class EmbeddingDtype : int {
  kFloat16 = 0,
  kFloat32 = 1,
  kFloat64 = 2,
  kInt32 = 3,
  kUint8 = 4
};
2674 
2726 inline Symbol Embedding(const std::string& symbol_name,
2727  Symbol data,
2728  Symbol weight,
2729  int input_dim,
2730  int output_dim,
2732  static const char *EmbeddingDtypeValues[] = {
2733  "float16",
2734  "float32",
2735  "float64",
2736  "int32",
2737  "uint8"
2738  };
2739  return Operator("Embedding")
2740  .SetParam("input_dim", input_dim)
2741  .SetParam("output_dim", output_dim)
2742  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
2743  .SetInput("data", data)
2744  .SetInput("weight", weight)
2745  .CreateSymbol(symbol_name);
2746 }
2747 
/// \brief Index-mode choices for `take`.
/// Each enumerator's numeric value indexes the mode-name table inside
/// `take`, so the order and values here must stay in sync with it.
enum class TakeMode : int {
  kClip = 0,
  kRaise = 1,
  kWrap = 2
};
2757 
2797 inline Symbol take(const std::string& symbol_name,
2798  Symbol a,
2799  Symbol indices,
2800  int axis = 0,
2801  TakeMode mode = TakeMode::kClip) {
2802  static const char *TakeModeValues[] = {
2803  "clip",
2804  "raise",
2805  "wrap"
2806  };
2807  return Operator("take")
2808  .SetParam("axis", axis)
2809  .SetParam("mode", TakeModeValues[int(mode)])
2810  .SetInput("a", a)
2811  .SetInput("indices", indices)
2812  .CreateSymbol(symbol_name);
2813 }
2814 
2843 inline Symbol batch_take(const std::string& symbol_name,
2844  Symbol a,
2845  Symbol indices) {
2846  return Operator("batch_take")
2847  .SetInput("a", a)
2848  .SetInput("indices", indices)
2849  .CreateSymbol(symbol_name);
2850 }
2851 
/// \brief dtype choices for `one_hot`.
/// Each enumerator's numeric value indexes the dtype-name table inside
/// `one_hot`, so the order and values here must stay in sync with it.
enum class One_hotDtype : int {
  kFloat16 = 0,
  kFloat32 = 1,
  kFloat64 = 2,
  kInt32 = 3,
  kUint8 = 4
};
2861 
2906 inline Symbol one_hot(const std::string& symbol_name,
2907  Symbol indices,
2908  int depth,
2909  double on_value = 1,
2910  double off_value = 0,
2912  static const char *One_hotDtypeValues[] = {
2913  "float16",
2914  "float32",
2915  "float64",
2916  "int32",
2917  "uint8"
2918  };
2919  return Operator("one_hot")
2920  .SetParam("depth", depth)
2921  .SetParam("on_value", on_value)
2922  .SetParam("off_value", off_value)
2923  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
2924  .SetInput("indices", indices)
2925  .CreateSymbol(symbol_name);
2926 }
2927 
2955 inline Symbol gather_nd(const std::string& symbol_name,
2956  Symbol data,
2957  Symbol indices) {
2958  return Operator("gather_nd")
2959  .SetInput("data", data)
2960  .SetInput("indices", indices)
2961  .CreateSymbol(symbol_name);
2962 }
2963 
3000 inline Symbol scatter_nd(const std::string& symbol_name,
3001  Symbol data,
3002  Symbol indices,
3003  Shape shape) {
3004  return Operator("scatter_nd")
3005  .SetParam("shape", shape)
3006  .SetInput("data", data)
3007  .SetInput("indices", indices)
3008  .CreateSymbol(symbol_name);
3009 }
3010 
3033 inline Symbol broadcast_equal(const std::string& symbol_name,
3034  Symbol lhs,
3035  Symbol rhs) {
3036  return Operator("broadcast_equal")
3037  .SetInput("lhs", lhs)
3038  .SetInput("rhs", rhs)
3039  .CreateSymbol(symbol_name);
3040 }
3041 
3064 inline Symbol broadcast_not_equal(const std::string& symbol_name,
3065  Symbol lhs,
3066  Symbol rhs) {
3067  return Operator("broadcast_not_equal")
3068  .SetInput("lhs", lhs)
3069  .SetInput("rhs", rhs)
3070  .CreateSymbol(symbol_name);
3071 }
3072 
3095 inline Symbol broadcast_greater(const std::string& symbol_name,
3096  Symbol lhs,
3097  Symbol rhs) {
3098  return Operator("broadcast_greater")
3099  .SetInput("lhs", lhs)
3100  .SetInput("rhs", rhs)
3101  .CreateSymbol(symbol_name);
3102 }
3103 
3126 inline Symbol broadcast_greater_equal(const std::string& symbol_name,
3127  Symbol lhs,
3128  Symbol rhs) {
3129  return Operator("broadcast_greater_equal")
3130  .SetInput("lhs", lhs)
3131  .SetInput("rhs", rhs)
3132  .CreateSymbol(symbol_name);
3133 }
3134 
3157 inline Symbol broadcast_lesser(const std::string& symbol_name,
3158  Symbol lhs,
3159  Symbol rhs) {
3160  return Operator("broadcast_lesser")
3161  .SetInput("lhs", lhs)
3162  .SetInput("rhs", rhs)
3163  .CreateSymbol(symbol_name);
3164 }
3165 
3188 inline Symbol broadcast_lesser_equal(const std::string& symbol_name,
3189  Symbol lhs,
3190  Symbol rhs) {
3191  return Operator("broadcast_lesser_equal")
3192  .SetInput("lhs", lhs)
3193  .SetInput("rhs", rhs)
3194  .CreateSymbol(symbol_name);
3195 }
3196 
3213 inline Symbol where(const std::string& symbol_name,
3214  Symbol condition,
3215  Symbol x,
3216  Symbol y) {
3217  return Operator("where")
3218  .SetInput("condition", condition)
3219  .SetInput("x", x)
3220  .SetInput("y", y)
3221  .CreateSymbol(symbol_name);
3222 }
3223 
3249 inline Symbol smooth_l1(const std::string& symbol_name,
3250  Symbol data,
3251  mx_float scalar) {
3252  return Operator("smooth_l1")
3253  .SetParam("scalar", scalar)
3254  .SetInput("data", data)
3255  .CreateSymbol(symbol_name);
3256 }
3257 
/// \brief Storage-type choices for `cast_storage`.
/// Each enumerator's numeric value indexes the stype-name table inside
/// `cast_storage`, so the order and values here must stay in sync with it.
enum class Cast_storageStype : int {
  kCsr = 0,
  kDefault = 1,
  kRow_sparse = 2
};
3265 
3309 inline Symbol cast_storage(const std::string& symbol_name,
3310  Symbol data,
3311  Cast_storageStype stype) {
3312  static const char *Cast_storageStypeValues[] = {
3313  "csr",
3314  "default",
3315  "row_sparse"
3316  };
3317  return Operator("cast_storage")
3318  .SetParam("stype", Cast_storageStypeValues[int(stype)])
3319  .SetInput("data", data)
3320  .CreateSymbol(symbol_name);
3321 }
3322 
3343 inline Symbol sin(const std::string& symbol_name,
3344  Symbol data) {
3345  return Operator("sin")
3346  .SetInput("data", data)
3347  .CreateSymbol(symbol_name);
3348 }
3349 
3367 inline Symbol cos(const std::string& symbol_name,
3368  Symbol data) {
3369  return Operator("cos")
3370  .SetInput("data", data)
3371  .CreateSymbol(symbol_name);
3372 }
3373 
3394 inline Symbol tan(const std::string& symbol_name,
3395  Symbol data) {
3396  return Operator("tan")
3397  .SetInput("data", data)
3398  .CreateSymbol(symbol_name);
3399 }
3400 
3422 inline Symbol arcsin(const std::string& symbol_name,
3423  Symbol data) {
3424  return Operator("arcsin")
3425  .SetInput("data", data)
3426  .CreateSymbol(symbol_name);
3427 }
3428 
3447 inline Symbol arccos(const std::string& symbol_name,
3448  Symbol data) {
3449  return Operator("arccos")
3450  .SetInput("data", data)
3451  .CreateSymbol(symbol_name);
3452 }
3453 
3474 inline Symbol arctan(const std::string& symbol_name,
3475  Symbol data) {
3476  return Operator("arctan")
3477  .SetInput("data", data)
3478  .CreateSymbol(symbol_name);
3479 }
3480 
3499 inline Symbol degrees(const std::string& symbol_name,
3500  Symbol data) {
3501  return Operator("degrees")
3502  .SetInput("data", data)
3503  .CreateSymbol(symbol_name);
3504 }
3505 
3524 inline Symbol radians(const std::string& symbol_name,
3525  Symbol data) {
3526  return Operator("radians")
3527  .SetInput("data", data)
3528  .CreateSymbol(symbol_name);
3529 }
3530 
3549 inline Symbol sinh(const std::string& symbol_name,
3550  Symbol data) {
3551  return Operator("sinh")
3552  .SetInput("data", data)
3553  .CreateSymbol(symbol_name);
3554 }
3555 
3571 inline Symbol cosh(const std::string& symbol_name,
3572  Symbol data) {
3573  return Operator("cosh")
3574  .SetInput("data", data)
3575  .CreateSymbol(symbol_name);
3576 }
3577 
3596 inline Symbol tanh(const std::string& symbol_name,
3597  Symbol data) {
3598  return Operator("tanh")
3599  .SetInput("data", data)
3600  .CreateSymbol(symbol_name);
3601 }
3602 
3619 inline Symbol arcsinh(const std::string& symbol_name,
3620  Symbol data) {
3621  return Operator("arcsinh")
3622  .SetInput("data", data)
3623  .CreateSymbol(symbol_name);
3624 }
3625 
3639 inline Symbol arccosh(const std::string& symbol_name,
3640  Symbol data) {
3641  return Operator("arccosh")
3642  .SetInput("data", data)
3643  .CreateSymbol(symbol_name);
3644 }
3645 
3662 inline Symbol arctanh(const std::string& symbol_name,
3663  Symbol data) {
3664  return Operator("arctanh")
3665  .SetInput("data", data)
3666  .CreateSymbol(symbol_name);
3667 }
3668 
3698 inline Symbol softmax(const std::string& symbol_name,
3699  Symbol data,
3700  int axis = -1) {
3701  return Operator("softmax")
3702  .SetParam("axis", axis)
3703  .SetInput("data", data)
3704  .CreateSymbol(symbol_name);
3705 }
3706 
3729 inline Symbol log_softmax(const std::string& symbol_name,
3730  Symbol data,
3731  int axis = -1) {
3732  return Operator("log_softmax")
3733  .SetParam("axis", axis)
3734  .SetInput("data", data)
3735  .CreateSymbol(symbol_name);
3736 }
3737 
/// \brief Sample-type choices for `UpSampling`.
/// The enum header line was missing, which left these enumerators outside any
/// declaration. Each enumerator's numeric value indexes the name table inside
/// `UpSampling`, so the order and values must stay in sync with it.
enum class UpSamplingSampleType {
  kBilinear = 0,
  kNearest = 1
};
3744 
/// \brief Multi-input-mode choices for `UpSampling`.
/// The enum header line was missing, which left these enumerators outside any
/// declaration. Each enumerator's numeric value indexes the name table inside
/// `UpSampling`, so the order and values must stay in sync with it.
enum class UpSamplingMultiInputMode {
  kConcat = 0,
  kSum = 1
};
3752 
3768 inline Symbol UpSampling(const std::string& symbol_name,
3769  const std::vector<Symbol>& data,
3770  uint32_t scale,
3771  UpSamplingSampleType sample_type,
3772  int num_args,
3773  uint32_t num_filter = 0,
3775  uint64_t workspace = 512) {
3776  static const char *UpSamplingSampleTypeValues[] = {
3777  "bilinear",
3778  "nearest"
3779  };
3780  static const char *UpSamplingMultiInputModeValues[] = {
3781  "concat",
3782  "sum"
3783  };
3784  return Operator("UpSampling")
3785  .SetParam("scale", scale)
3786  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
3787  .SetParam("num_args", num_args)
3788  .SetParam("num_filter", num_filter)
3789  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
3790  .SetParam("workspace", workspace)
3791 (data)
3792  .CreateSymbol(symbol_name);
3793 }
3794 
3858 inline Symbol BatchNorm(const std::string& symbol_name,
3859  Symbol data,
3860  Symbol gamma,
3861  Symbol beta,
3862  Symbol moving_mean,
3863  Symbol moving_var,
3864  double eps = 0.001,
3865  mx_float momentum = 0.9,
3866  bool fix_gamma = true,
3867  bool use_global_stats = false,
3868  bool output_mean_var = false,
3869  int axis = 1,
3870  bool cudnn_off = false) {
3871  return Operator("BatchNorm")
3872  .SetParam("eps", eps)
3873  .SetParam("momentum", momentum)
3874  .SetParam("fix_gamma", fix_gamma)
3875  .SetParam("use_global_stats", use_global_stats)
3876  .SetParam("output_mean_var", output_mean_var)
3877  .SetParam("axis", axis)
3878  .SetParam("cudnn_off", cudnn_off)
3879  .SetInput("data", data)
3880  .SetInput("gamma", gamma)
3881  .SetInput("beta", beta)
3882  .SetInput("moving_mean", moving_mean)
3883  .SetInput("moving_var", moving_var)
3884  .CreateSymbol(symbol_name);
3885 }
3886 
/// \brief Padding-mode choices for `Pad`.
/// Each enumerator's numeric value indexes the mode-name table inside
/// `Pad`, so the order and values here must stay in sync with it.
enum class PadMode : int {
  kConstant = 0,
  kEdge = 1,
  kReflect = 2
};
3895 
3992 inline Symbol Pad(const std::string& symbol_name,
3993  Symbol data,
3994  PadMode mode,
3995  Shape pad_width,
3996  double constant_value = 0) {
3997  static const char *PadModeValues[] = {
3998  "constant",
3999  "edge",
4000  "reflect"
4001  };
4002  return Operator("Pad")
4003  .SetParam("mode", PadModeValues[int(mode)])
4004  .SetParam("pad_width", pad_width)
4005  .SetParam("constant_value", constant_value)
4006  .SetInput("data", data)
4007  .CreateSymbol(symbol_name);
4008 }
4009 
4051 inline Symbol Concat(const std::string& symbol_name,
4052  const std::vector<Symbol>& data,
4053  int num_args,
4054  int dim = 1) {
4055  return Operator("Concat")
4056  .SetParam("num_args", num_args)
4057  .SetParam("dim", dim)
4058 (data)
4059  .CreateSymbol(symbol_name);
4060 }
4061 
/// \brief Activation-type choices for `LeakyReLU`.
/// Each enumerator's numeric value indexes the name table inside
/// `LeakyReLU`, so the order and values here must stay in sync with it.
enum class LeakyReLUActType : int {
  kElu = 0,
  kLeaky = 1,
  kPrelu = 2,
  kRrelu = 3
};
4070 
4097 inline Symbol LeakyReLU(const std::string& symbol_name,
4098  Symbol data,
4100  mx_float slope = 0.25,
4101  mx_float lower_bound = 0.125,
4102  mx_float upper_bound = 0.334) {
4103  static const char *LeakyReLUActTypeValues[] = {
4104  "elu",
4105  "leaky",
4106  "prelu",
4107  "rrelu"
4108  };
4109  return Operator("LeakyReLU")
4110  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
4111  .SetParam("slope", slope)
4112  .SetParam("lower_bound", lower_bound)
4113  .SetParam("upper_bound", upper_bound)
4114  .SetInput("data", data)
4115  .CreateSymbol(symbol_name);
4116 }
4117 
4189 inline Symbol SliceChannel(const std::string& symbol_name,
4190  Symbol data,
4191  int num_outputs,
4192  int axis = 1,
4193  bool squeeze_axis = false) {
4194  return Operator("SliceChannel")
4195  .SetParam("num_outputs", num_outputs)
4196  .SetParam("axis", axis)
4197  .SetParam("squeeze_axis", squeeze_axis)
4198  .SetInput("data", data)
4199  .CreateSymbol(symbol_name);
4200 }
4201 
4230 inline Symbol SwapAxis(const std::string& symbol_name,
4231  Symbol data,
4232  uint32_t dim1 = 0,
4233  uint32_t dim2 = 0) {
4234  return Operator("SwapAxis")
4235  .SetParam("dim1", dim1)
4236  .SetParam("dim2", dim2)
4237  .SetInput("data", data)
4238  .CreateSymbol(symbol_name);
4239 }
4240 
4296 inline Symbol BatchNorm_v1(const std::string& symbol_name,
4297  Symbol data,
4298  Symbol gamma,
4299  Symbol beta,
4300  mx_float eps = 0.001,
4301  mx_float momentum = 0.9,
4302  bool fix_gamma = true,
4303  bool use_global_stats = false,
4304  bool output_mean_var = false) {
4305  return Operator("BatchNorm_v1")
4306  .SetParam("eps", eps)
4307  .SetParam("momentum", momentum)
4308  .SetParam("fix_gamma", fix_gamma)
4309  .SetParam("use_global_stats", use_global_stats)
4310  .SetParam("output_mean_var", output_mean_var)
4311  .SetInput("data", data)
4312  .SetInput("gamma", gamma)
4313  .SetInput("beta", beta)
4314  .CreateSymbol(symbol_name);
4315 }
4316 
4354 inline Symbol softmax_cross_entropy(const std::string& symbol_name,
4355  Symbol data,
4356  Symbol label) {
4357  return Operator("softmax_cross_entropy")
4358  .SetInput("data", data)
4359  .SetInput("label", label)
4360  .CreateSymbol(symbol_name);
4361 }
4362 
4387 inline Symbol LinearRegressionOutput(const std::string& symbol_name,
4388  Symbol data,
4389  Symbol label,
4390  mx_float grad_scale = 1) {
4391  return Operator("LinearRegressionOutput")
4392  .SetParam("grad_scale", grad_scale)
4393  .SetInput("data", data)
4394  .SetInput("label", label)
4395  .CreateSymbol(symbol_name);
4396 }
4397 
4423 inline Symbol MAERegressionOutput(const std::string& symbol_name,
4424  Symbol data,
4425  Symbol label,
4426  mx_float grad_scale = 1) {
4427  return Operator("MAERegressionOutput")
4428  .SetParam("grad_scale", grad_scale)
4429  .SetInput("data", data)
4430  .SetInput("label", label)
4431  .CreateSymbol(symbol_name);
4432 }
4433 
4459 inline Symbol LogisticRegressionOutput(const std::string& symbol_name,
4460  Symbol data,
4461  Symbol label,
4462  mx_float grad_scale = 1) {
4463  return Operator("LogisticRegressionOutput")
4464  .SetParam("grad_scale", grad_scale)
4465  .SetInput("data", data)
4466  .SetInput("label", label)
4467  .CreateSymbol(symbol_name);
4468 }
4469 
4479 inline Symbol IdentityAttachKLSparseReg(const std::string& symbol_name,
4480  Symbol data,
4481  mx_float sparseness_target = 0.1,
4482  mx_float penalty = 0.001,
4483  mx_float momentum = 0.9) {
4484  return Operator("IdentityAttachKLSparseReg")
4485  .SetParam("sparseness_target", sparseness_target)
4486  .SetParam("penalty", penalty)
4487  .SetParam("momentum", momentum)
4488  .SetInput("data", data)
4489  .CreateSymbol(symbol_name);
4490 }
4491 
4519 inline Symbol signsgd_update(const std::string& symbol_name,
4520  Symbol weight,
4521  Symbol grad,
4522  mx_float lr,
4523  mx_float wd = 0,
4524  mx_float rescale_grad = 1,
4525  mx_float clip_gradient = -1) {
4526  return Operator("signsgd_update")
4527  .SetParam("lr", lr)
4528  .SetParam("wd", wd)
4529  .SetParam("rescale_grad", rescale_grad)
4530  .SetParam("clip_gradient", clip_gradient)
4531  .SetInput("weight", weight)
4532  .SetInput("grad", grad)
4533  .CreateSymbol(symbol_name);
4534 }
4535 
4570 inline Symbol signum_update(const std::string& symbol_name,
4571  Symbol weight,
4572  Symbol grad,
4573  Symbol mom,
4574  mx_float lr,
4575  mx_float momentum = 0,
4576  mx_float wd = 0,
4577  mx_float rescale_grad = 1,
4578  mx_float clip_gradient = -1,
4579  mx_float wd_lh = 0) {
4580  return Operator("signum_update")
4581  .SetParam("lr", lr)
4582  .SetParam("momentum", momentum)
4583  .SetParam("wd", wd)
4584  .SetParam("rescale_grad", rescale_grad)
4585  .SetParam("clip_gradient", clip_gradient)
4586  .SetParam("wd_lh", wd_lh)
4587  .SetInput("weight", weight)
4588  .SetInput("grad", grad)
4589  .SetInput("mom", mom)
4590  .CreateSymbol(symbol_name);
4591 }
4592 
4620 inline Symbol sgd_update(const std::string& symbol_name,
4621  Symbol weight,
4622  Symbol grad,
4623  mx_float lr,
4624  mx_float wd = 0,
4625  mx_float rescale_grad = 1,
4626  mx_float clip_gradient = -1) {
4627  return Operator("sgd_update")
4628  .SetParam("lr", lr)
4629  .SetParam("wd", wd)
4630  .SetParam("rescale_grad", rescale_grad)
4631  .SetParam("clip_gradient", clip_gradient)
4632  .SetInput("weight", weight)
4633  .SetInput("grad", grad)
4634  .CreateSymbol(symbol_name);
4635 }
4636 
4682 inline Symbol sgd_mom_update(const std::string& symbol_name,
4683  Symbol weight,
4684  Symbol grad,
4685  Symbol mom,
4686  mx_float lr,
4687  mx_float momentum = 0,
4688  mx_float wd = 0,
4689  mx_float rescale_grad = 1,
4690  mx_float clip_gradient = -1) {
4691  return Operator("sgd_mom_update")
4692  .SetParam("lr", lr)
4693  .SetParam("momentum", momentum)
4694  .SetParam("wd", wd)
4695  .SetParam("rescale_grad", rescale_grad)
4696  .SetParam("clip_gradient", clip_gradient)
4697  .SetInput("weight", weight)
4698  .SetInput("grad", grad)
4699  .SetInput("mom", mom)
4700  .CreateSymbol(symbol_name);
4701 }
4702 
4717 inline Symbol mp_sgd_update(const std::string& symbol_name,
4718  Symbol weight,
4719  Symbol grad,
4720  Symbol weight32,
4721  mx_float lr,
4722  mx_float wd = 0,
4723  mx_float rescale_grad = 1,
4724  mx_float clip_gradient = -1) {
4725  return Operator("mp_sgd_update")
4726  .SetParam("lr", lr)
4727  .SetParam("wd", wd)
4728  .SetParam("rescale_grad", rescale_grad)
4729  .SetParam("clip_gradient", clip_gradient)
4730  .SetInput("weight", weight)
4731  .SetInput("grad", grad)
4732  .SetInput("weight32", weight32)
4733  .CreateSymbol(symbol_name);
4734 }
4735 
4752 inline Symbol mp_sgd_mom_update(const std::string& symbol_name,
4753  Symbol weight,
4754  Symbol grad,
4755  Symbol mom,
4756  Symbol weight32,
4757  mx_float lr,
4758  mx_float momentum = 0,
4759  mx_float wd = 0,
4760  mx_float rescale_grad = 1,
4761  mx_float clip_gradient = -1) {
4762  return Operator("mp_sgd_mom_update")
4763  .SetParam("lr", lr)
4764  .SetParam("momentum", momentum)
4765  .SetParam("wd", wd)
4766  .SetParam("rescale_grad", rescale_grad)
4767  .SetParam("clip_gradient", clip_gradient)
4768  .SetInput("weight", weight)
4769  .SetInput("grad", grad)
4770  .SetInput("mom", mom)
4771  .SetInput("weight32", weight32)
4772  .CreateSymbol(symbol_name);
4773 }
4774 
4809 inline Symbol ftml_update(const std::string& symbol_name,
4810  Symbol weight,
4811  Symbol grad,
4812  Symbol d,
4813  Symbol v,
4814  Symbol z,
4815  mx_float lr,
4816  mx_float beta1 = 0.9,
4817  mx_float beta2 = 0.999,
4818  mx_float epsilon = 1e-08,
4819  mx_float wd = 0,
4820  mx_float rescale_grad = 1,
4821  mx_float clip_gradient = -1) {
4822  return Operator("ftml_update")
4823  .SetParam("lr", lr)
4824  .SetParam("beta1", beta1)
4825  .SetParam("beta2", beta2)
4826  .SetParam("epsilon", epsilon)
4827  .SetParam("wd", wd)
4828  .SetParam("rescale_grad", rescale_grad)
4829  .SetParam("clip_gradient", clip_gradient)
4830  .SetInput("weight", weight)
4831  .SetInput("grad", grad)
4832  .SetInput("d", d)
4833  .SetInput("v", v)
4834  .SetInput("z", z)
4835  .CreateSymbol(symbol_name);
4836 }
4837 
4885 inline Symbol adam_update(const std::string& symbol_name,
4886  Symbol weight,
4887  Symbol grad,
4888  Symbol mean,
4889  Symbol var,
4890  mx_float lr,
4891  mx_float beta1 = 0.9,
4892  mx_float beta2 = 0.999,
4893  mx_float epsilon = 1e-08,
4894  mx_float wd = 0,
4895  mx_float rescale_grad = 1,
4896  mx_float clip_gradient = -1) {
4897  return Operator("adam_update")
4898  .SetParam("lr", lr)
4899  .SetParam("beta1", beta1)
4900  .SetParam("beta2", beta2)
4901  .SetParam("epsilon", epsilon)
4902  .SetParam("wd", wd)
4903  .SetParam("rescale_grad", rescale_grad)
4904  .SetParam("clip_gradient", clip_gradient)
4905  .SetInput("weight", weight)
4906  .SetInput("grad", grad)
4907  .SetInput("mean", mean)
4908  .SetInput("var", var)
4909  .CreateSymbol(symbol_name);
4910 }
4911 
4965 inline Symbol rmsprop_update(const std::string& symbol_name,
4966  Symbol weight,
4967  Symbol grad,
4968  Symbol n,
4969  mx_float lr,
4970  mx_float gamma1 = 0.95,
4971  mx_float epsilon = 1e-08,
4972  mx_float wd = 0,
4973  mx_float rescale_grad = 1,
4974  mx_float clip_gradient = -1,
4975  mx_float clip_weights = -1) {
4976  return Operator("rmsprop_update")
4977  .SetParam("lr", lr)
4978  .SetParam("gamma1", gamma1)
4979  .SetParam("epsilon", epsilon)
4980  .SetParam("wd", wd)
4981  .SetParam("rescale_grad", rescale_grad)
4982  .SetParam("clip_gradient", clip_gradient)
4983  .SetParam("clip_weights", clip_weights)
4984  .SetInput("weight", weight)
4985  .SetInput("grad", grad)
4986  .SetInput("n", n)
4987  .CreateSymbol(symbol_name);
4988 }
4989 
5035 inline Symbol rmspropalex_update(const std::string& symbol_name,
5036  Symbol weight,
5037  Symbol grad,
5038  Symbol n,
5039  Symbol g,
5040  Symbol delta,
5041  mx_float lr,
5042  mx_float gamma1 = 0.95,
5043  mx_float gamma2 = 0.9,
5044  mx_float epsilon = 1e-08,
5045  mx_float wd = 0,
5046  mx_float rescale_grad = 1,
5047  mx_float clip_gradient = -1,
5048  mx_float clip_weights = -1) {
5049  return Operator("rmspropalex_update")
5050  .SetParam("lr", lr)
5051  .SetParam("gamma1", gamma1)
5052  .SetParam("gamma2", gamma2)
5053  .SetParam("epsilon", epsilon)
5054  .SetParam("wd", wd)
5055  .SetParam("rescale_grad", rescale_grad)
5056  .SetParam("clip_gradient", clip_gradient)
5057  .SetParam("clip_weights", clip_weights)
5058  .SetInput("weight", weight)
5059  .SetInput("grad", grad)
5060  .SetInput("n", n)
5061  .SetInput("g", g)
5062  .SetInput("delta", delta)
5063  .CreateSymbol(symbol_name);
5064 }
5065 
5105 inline Symbol ftrl_update(const std::string& symbol_name,
5106  Symbol weight,
5107  Symbol grad,
5108  Symbol z,
5109  Symbol n,
5110  mx_float lr,
5111  mx_float lamda1 = 0.01,
5112  mx_float beta = 1,
5113  mx_float wd = 0,
5114  mx_float rescale_grad = 1,
5115  mx_float clip_gradient = -1) {
5116  return Operator("ftrl_update")
5117  .SetParam("lr", lr)
5118  .SetParam("lamda1", lamda1)
5119  .SetParam("beta", beta)
5120  .SetParam("wd", wd)
5121  .SetParam("rescale_grad", rescale_grad)
5122  .SetParam("clip_gradient", clip_gradient)
5123  .SetInput("weight", weight)
5124  .SetInput("grad", grad)
5125  .SetInput("z", z)
5126  .SetInput("n", n)
5127  .CreateSymbol(symbol_name);
5128 }
5129 
/*!
 * \brief Values accepted for the "pool_type" argument of Pooling().
 *
 * The numeric value indexes Pooling()'s string table; keep them aligned.
 */
enum class PoolingPoolType : int {
  kAvg = 0,  // "avg"
  kMax = 1,  // "max"
  kSum = 2   // "sum"
};
5137 
5141  kFull = 0,
5142  kValid = 1
5143 };
5144 
5198 inline Symbol Pooling(const std::string& symbol_name,
5199  Symbol data,
5200  Shape kernel,
5201  PoolingPoolType pool_type,
5202  bool global_pool = false,
5203  bool cudnn_off = false,
5205  Shape stride = Shape(),
5206  Shape pad = Shape()) {
5207  static const char *PoolingPoolTypeValues[] = {
5208  "avg",
5209  "max",
5210  "sum"
5211  };
5212  static const char *PoolingPoolingConventionValues[] = {
5213  "full",
5214  "valid"
5215  };
5216  return Operator("Pooling")
5217  .SetParam("kernel", kernel)
5218  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
5219  .SetParam("global_pool", global_pool)
5220  .SetParam("cudnn_off", cudnn_off)
5221  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
5222  .SetParam("stride", stride)
5223  .SetParam("pad", pad)
5224  .SetInput("data", data)
5225  .CreateSymbol(symbol_name);
5226 }
5227 
5231  kNone = 0,
5232  kFastest = 1,
5233  kLimited_workspace = 2,
5234  kOff = 3
5235 };
5236 
5240  kNone = 0,
5241  kNCDHW = 1,
5242  kNCHW = 2,
5243  kNCW = 3,
5244  kNDHWC = 4,
5245  kNHWC = 5
5246 };
5247 
5274 inline Symbol Deconvolution(const std::string& symbol_name,
5275  Symbol data,
5276  Symbol weight,
5277  Symbol bias,
5278  Shape kernel,
5279  uint32_t num_filter,
5280  Shape stride = Shape(),
5281  Shape dilate = Shape(),
5282  Shape pad = Shape(),
5283  Shape adj = Shape(),
5284  Shape target_shape = Shape(),
5285  uint32_t num_group = 1,
5286  uint64_t workspace = 512,
5287  bool no_bias = true,
5289  bool cudnn_off = false,
5291  static const char *DeconvolutionCudnnTuneValues[] = {
5292  "None",
5293  "fastest",
5294  "limited_workspace",
5295  "off"
5296  };
5297  static const char *DeconvolutionLayoutValues[] = {
5298  "None",
5299  "NCDHW",
5300  "NCHW",
5301  "NCW",
5302  "NDHWC",
5303  "NHWC"
5304  };
5305  return Operator("Deconvolution")
5306  .SetParam("kernel", kernel)
5307  .SetParam("num_filter", num_filter)
5308  .SetParam("stride", stride)
5309  .SetParam("dilate", dilate)
5310  .SetParam("pad", pad)
5311  .SetParam("adj", adj)
5312  .SetParam("target_shape", target_shape)
5313  .SetParam("num_group", num_group)
5314  .SetParam("workspace", workspace)
5315  .SetParam("no_bias", no_bias)
5316  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
5317  .SetParam("cudnn_off", cudnn_off)
5318  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
5319  .SetInput("data", data)
5320  .SetInput("weight", weight)
5321  .SetInput("bias", bias)
5322  .CreateSymbol(symbol_name);
5323 }
5324 
/*!
 * \brief Values accepted for the "act_type" argument of Activation().
 *
 * The numeric value indexes Activation()'s string table; keep them aligned.
 */
enum class ActivationActType : int {
  kRelu = 0,      // "relu"
  kSigmoid = 1,   // "sigmoid"
  kSoftrelu = 2,  // "softrelu"
  kTanh = 3       // "tanh"
};
5333 
5352 inline Symbol Activation(const std::string& symbol_name,
5353  Symbol data,
5354  ActivationActType act_type) {
5355  static const char *ActivationActTypeValues[] = {
5356  "relu",
5357  "sigmoid",
5358  "softrelu",
5359  "tanh"
5360  };
5361  return Operator("Activation")
5362  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
5363  .SetInput("data", data)
5364  .CreateSymbol(symbol_name);
5365 }
5366 
5370  kNone = 0,
5371  kFastest = 1,
5372  kLimited_workspace = 2,
5373  kOff = 3
5374 };
5375 
5379 enum class ConvolutionLayout {
5380  kNone = 0,
5381  kNCDHW = 1,
5382  kNCHW = 2,
5383  kNCW = 3,
5384  kNDHWC = 4,
5385  kNHWC = 5
5386 };
5387 
5482 inline Symbol Convolution(const std::string& symbol_name,
5483  Symbol data,
5484  Symbol weight,
5485  Symbol bias,
5486  Shape kernel,
5487  uint32_t num_filter,
5488  Shape stride = Shape(),
5489  Shape dilate = Shape(),
5490  Shape pad = Shape(),
5491  uint32_t num_group = 1,
5492  uint64_t workspace = 1024,
5493  bool no_bias = false,
5495  bool cudnn_off = false,
5497  static const char *ConvolutionCudnnTuneValues[] = {
5498  "None",
5499  "fastest",
5500  "limited_workspace",
5501  "off"
5502  };
5503  static const char *ConvolutionLayoutValues[] = {
5504  "None",
5505  "NCDHW",
5506  "NCHW",
5507  "NCW",
5508  "NDHWC",
5509  "NHWC"
5510  };
5511  return Operator("Convolution")
5512  .SetParam("kernel", kernel)
5513  .SetParam("num_filter", num_filter)
5514  .SetParam("stride", stride)
5515  .SetParam("dilate", dilate)
5516  .SetParam("pad", pad)
5517  .SetParam("num_group", num_group)
5518  .SetParam("workspace", workspace)
5519  .SetParam("no_bias", no_bias)
5520  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
5521  .SetParam("cudnn_off", cudnn_off)
5522  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
5523  .SetInput("data", data)
5524  .SetInput("weight", weight)
5525  .SetInput("bias", bias)
5526  .CreateSymbol(symbol_name);
5527 }
5528 
/*!
 * \brief Values accepted for the "mode" argument of Dropout().
 *
 * The numeric value indexes Dropout()'s string table; keep them aligned.
 */
enum class DropoutMode : int {
  kAlways = 0,   // "always"
  kTraining = 1  // "training"
};
5535 
5575 inline Symbol Dropout(const std::string& symbol_name,
5576  Symbol data,
5577  mx_float p = 0.5,
5579  static const char *DropoutModeValues[] = {
5580  "always",
5581  "training"
5582  };
5583  return Operator("Dropout")
5584  .SetParam("p", p)
5585  .SetParam("mode", DropoutModeValues[int(mode)])
5586  .SetInput("data", data)
5587  .CreateSymbol(symbol_name);
5588 }
5589 
5594  kChannel = 0,
5595  kInstance = 1
5596 };
5597 
5631 inline Symbol SoftmaxActivation(const std::string& symbol_name,
5632  Symbol data,
5634  static const char *SoftmaxActivationModeValues[] = {
5635  "channel",
5636  "instance"
5637  };
5638  return Operator("SoftmaxActivation")
5639  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
5640  .SetInput("data", data)
5641  .CreateSymbol(symbol_name);
5642 }
5643 
5677 inline Symbol FullyConnected(const std::string& symbol_name,
5678  Symbol data,
5679  Symbol weight,
5680  Symbol bias,
5681  int num_hidden,
5682  bool no_bias = false,
5683  bool flatten = true) {
5684  return Operator("FullyConnected")
5685  .SetParam("num_hidden", num_hidden)
5686  .SetParam("no_bias", no_bias)
5687  .SetParam("flatten", flatten)
5688  .SetInput("data", data)
5689  .SetInput("weight", weight)
5690  .SetInput("bias", bias)
5691  .CreateSymbol(symbol_name);
5692 }
5693 
5744 inline Symbol InstanceNorm(const std::string& symbol_name,
5745  Symbol data,
5746  Symbol gamma,
5747  Symbol beta,
5748  mx_float eps = 0.001) {
5749  return Operator("InstanceNorm")
5750  .SetParam("eps", eps)
5751  .SetInput("data", data)
5752  .SetInput("gamma", gamma)
5753  .SetInput("beta", beta)
5754  .CreateSymbol(symbol_name);
5755 }
5756 
5761  kAffine = 0,
5762  kWarp = 1
5763 };
5764 
5775 inline Symbol GridGenerator(const std::string& symbol_name,
5776  Symbol data,
5777  GridGeneratorTransformType transform_type,
5778  Shape target_shape = Shape(0,0)) {
5779  static const char *GridGeneratorTransformTypeValues[] = {
5780  "affine",
5781  "warp"
5782  };
5783  return Operator("GridGenerator")
5784  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
5785  .SetParam("target_shape", target_shape)
5786  .SetInput("data", data)
5787  .CreateSymbol(symbol_name);
5788 }
5789 
5793  kAvg = 0,
5794  kMax = 1,
5795  kSum = 2
5796 };
5797 
5801  kFull = 0,
5802  kValid = 1
5803 };
5804 
5856 inline Symbol Pooling_v1(const std::string& symbol_name,
5857  Symbol data,
5858  Shape kernel,
5859  Pooling_v1PoolType pool_type,
5860  bool global_pool = false,
5862  Shape stride = Shape(),
5863  Shape pad = Shape()) {
5864  static const char *Pooling_v1PoolTypeValues[] = {
5865  "avg",
5866  "max",
5867  "sum"
5868  };
5869  static const char *Pooling_v1PoolingConventionValues[] = {
5870  "full",
5871  "valid"
5872  };
5873  return Operator("Pooling_v1")
5874  .SetParam("kernel", kernel)
5875  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
5876  .SetParam("global_pool", global_pool)
5877  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
5878  .SetParam("stride", stride)
5879  .SetParam("pad", pad)
5880  .SetInput("data", data)
5881  .CreateSymbol(symbol_name);
5882 }
5883 
/*!
 * \brief Values accepted for the "mode" argument of RNN().
 *
 * The numeric value indexes RNN()'s string table; keep them aligned.
 */
enum class RNNMode : int {
  kGru = 0,       // "gru"
  kLstm = 1,      // "lstm"
  kRnn_relu = 2,  // "rnn_relu"
  kRnn_tanh = 3   // "rnn_tanh"
};
5892 
5908 inline Symbol RNN(const std::string& symbol_name,
5909  Symbol data,
5910  Symbol parameters,
5911  Symbol state,
5912  Symbol state_cell,
5913  uint32_t state_size,
5914  uint32_t num_layers,
5915  RNNMode mode,
5916  bool bidirectional = false,
5917  mx_float p = 0,
5918  bool state_outputs = false) {
5919  static const char *RNNModeValues[] = {
5920  "gru",
5921  "lstm",
5922  "rnn_relu",
5923  "rnn_tanh"
5924  };
5925  return Operator("RNN")
5926  .SetParam("state_size", state_size)
5927  .SetParam("num_layers", num_layers)
5928  .SetParam("mode", RNNModeValues[int(mode)])
5929  .SetParam("bidirectional", bidirectional)
5930  .SetParam("p", p)
5931  .SetParam("state_outputs", state_outputs)
5932  .SetInput("data", data)
5933  .SetInput("parameters", parameters)
5934  .SetInput("state", state)
5935  .SetInput("state_cell", state_cell)
5936  .CreateSymbol(symbol_name);
5937 }
5938 
5949  kNone = 0,
5950  kFastest = 1,
5951  kLimited_workspace = 2,
5952  kOff = 3
5953 };
5954 
5959  kNone = 0,
5960  kNCDHW = 1,
5961  kNCHW = 2,
5962  kNDHWC = 3,
5963  kNHWC = 4
5964 };
5965 
5994 inline Symbol Convolution_v1(const std::string& symbol_name,
5995  Symbol data,
5996  Symbol weight,
5997  Symbol bias,
5998  Shape kernel,
5999  uint32_t num_filter,
6000  Shape stride = Shape(),
6001  Shape dilate = Shape(),
6002  Shape pad = Shape(),
6003  uint32_t num_group = 1,
6004  uint64_t workspace = 1024,
6005  bool no_bias = false,
6007  bool cudnn_off = false,
6009  static const char *Convolution_v1CudnnTuneValues[] = {
6010  "None",
6011  "fastest",
6012  "limited_workspace",
6013  "off"
6014  };
6015  static const char *Convolution_v1LayoutValues[] = {
6016  "None",
6017  "NCDHW",
6018  "NCHW",
6019  "NDHWC",
6020  "NHWC"
6021  };
6022  return Operator("Convolution_v1")
6023  .SetParam("kernel", kernel)
6024  .SetParam("num_filter", num_filter)
6025  .SetParam("stride", stride)
6026  .SetParam("dilate", dilate)
6027  .SetParam("pad", pad)
6028  .SetParam("num_group", num_group)
6029  .SetParam("workspace", workspace)
6030  .SetParam("no_bias", no_bias)
6031  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
6032  .SetParam("cudnn_off", cudnn_off)
6033  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
6034  .SetInput("data", data)
6035  .SetInput("weight", weight)
6036  .SetInput("bias", bias)
6037  .CreateSymbol(symbol_name);
6038 }
6039 
6060 inline Symbol Crop(const std::string& symbol_name,
6061  const std::vector<Symbol>& data,
6062  int num_args,
6063  Shape offset = Shape(0,0),
6064  Shape h_w = Shape(0,0),
6065  bool center_crop = false) {
6066  return Operator("Crop")
6067  .SetParam("num_args", num_args)
6068  .SetParam("offset", offset)
6069  .SetParam("h_w", h_w)
6070  .SetParam("center_crop", center_crop)
6071 (data)
6072  .CreateSymbol(symbol_name);
6073 }
6074 
6151 inline Symbol SequenceReverse(const std::string& symbol_name,
6152  Symbol data,
6153  Symbol sequence_length,
6154  bool use_sequence_length = false,
6155  int axis = 0) {
6156  return Operator("SequenceReverse")
6157  .SetParam("use_sequence_length", use_sequence_length)
6158  .SetParam("axis", axis)
6159  .SetInput("data", data)
6160  .SetInput("sequence_length", sequence_length)
6161  .CreateSymbol(symbol_name);
6162 }
6163 
6167  kAffine = 0
6168 };
6169 
6173  kBilinear = 0
6174 };
6175 
6186 inline Symbol SpatialTransformer(const std::string& symbol_name,
6187  Symbol data,
6188  Symbol loc,
6189  SpatialTransformerTransformType transform_type,
6190  SpatialTransformerSamplerType sampler_type,
6191  Shape target_shape = Shape(0,0)) {
6192  static const char *SpatialTransformerTransformTypeValues[] = {
6193  "affine"
6194  };
6195  static const char *SpatialTransformerSamplerTypeValues[] = {
6196  "bilinear"
6197  };
6198  return Operator("SpatialTransformer")
6199  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
6200  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
6201  .SetParam("target_shape", target_shape)
6202  .SetInput("data", data)
6203  .SetInput("loc", loc)
6204  .CreateSymbol(symbol_name);
6205 }
6206 
6262 inline Symbol SequenceLast(const std::string& symbol_name,
6263  Symbol data,
6264  Symbol sequence_length,
6265  bool use_sequence_length = false,
6266  int axis = 0) {
6267  return Operator("SequenceLast")
6268  .SetParam("use_sequence_length", use_sequence_length)
6269  .SetParam("axis", axis)
6270  .SetInput("data", data)
6271  .SetInput("sequence_length", sequence_length)
6272  .CreateSymbol(symbol_name);
6273 }
6274 
6278  kBatch = 0,
6279  kNull = 1,
6280  kValid = 2
6281 };
6282 
6378 inline Symbol SoftmaxOutput(const std::string& symbol_name,
6379  Symbol data,
6380  Symbol label,
6381  mx_float grad_scale = 1,
6382  mx_float ignore_label = -1,
6383  bool multi_output = false,
6384  bool use_ignore = false,
6385  bool preserve_shape = false,
6387  bool out_grad = false,
6388  mx_float smooth_alpha = 0) {
6389  static const char *SoftmaxOutputNormalizationValues[] = {
6390  "batch",
6391  "null",
6392  "valid"
6393  };
6394  return Operator("SoftmaxOutput")
6395  .SetParam("grad_scale", grad_scale)
6396  .SetParam("ignore_label", ignore_label)
6397  .SetParam("multi_output", multi_output)
6398  .SetParam("use_ignore", use_ignore)
6399  .SetParam("preserve_shape", preserve_shape)
6400  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
6401  .SetParam("out_grad", out_grad)
6402  .SetParam("smooth_alpha", smooth_alpha)
6403  .SetInput("data", data)
6404  .SetInput("label", label)
6405  .CreateSymbol(symbol_name);
6406 }
6407 
6411  kBatch = 0,
6412  kNull = 1,
6413  kValid = 2
6414 };
6415 
6443 inline Symbol Softmax(const std::string& symbol_name,
6444  Symbol data,
6445  mx_float grad_scale = 1,
6446  mx_float ignore_label = -1,
6447  bool multi_output = false,
6448  bool use_ignore = false,
6449  bool preserve_shape = false,
6451  bool out_grad = false,
6452  mx_float smooth_alpha = 0) {
6453  static const char *SoftmaxNormalizationValues[] = {
6454  "batch",
6455  "null",
6456  "valid"
6457  };
6458  return Operator("Softmax")
6459  .SetParam("grad_scale", grad_scale)
6460  .SetParam("ignore_label", ignore_label)
6461  .SetParam("multi_output", multi_output)
6462  .SetParam("use_ignore", use_ignore)
6463  .SetParam("preserve_shape", preserve_shape)
6464  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
6465  .SetParam("out_grad", out_grad)
6466  .SetParam("smooth_alpha", smooth_alpha)
6467  .SetInput("data", data)
6468  .CreateSymbol(symbol_name);
6469 }
6470 
6551 inline Symbol BilinearSampler(const std::string& symbol_name,
6552  Symbol data,
6553  Symbol grid) {
6554  return Operator("BilinearSampler")
6555  .SetInput("data", data)
6556  .SetInput("grid", grid)
6557  .CreateSymbol(symbol_name);
6558 }
6559 
6616 inline Symbol ROIPooling(const std::string& symbol_name,
6617  Symbol data,
6618  Symbol rois,
6619  Shape pooled_size,
6620  mx_float spatial_scale) {
6621  return Operator("ROIPooling")
6622  .SetParam("pooled_size", pooled_size)
6623  .SetParam("spatial_scale", spatial_scale)
6624  .SetInput("data", data)
6625  .SetInput("rois", rois)
6626  .CreateSymbol(symbol_name);
6627 }
6628 
6632  kChannel = 0,
6633  kInstance = 1,
6634  kSpatial = 2
6635 };
6636 
6699 inline Symbol L2Normalization(const std::string& symbol_name,
6700  Symbol data,
6701  mx_float eps = 1e-10,
6703  static const char *L2NormalizationModeValues[] = {
6704  "channel",
6705  "instance",
6706  "spatial"
6707  };
6708  return Operator("L2Normalization")
6709  .SetParam("eps", eps)
6710  .SetParam("mode", L2NormalizationModeValues[int(mode)])
6711  .SetInput("data", data)
6712  .CreateSymbol(symbol_name);
6713 }
6714 
6720  kBatch = 0,
6721  kNull = 1,
6722  kValid = 2
6723 };
6724 
6759 inline Symbol MakeLoss(const std::string& symbol_name,
6760  Symbol data,
6761  mx_float grad_scale = 1,
6762  mx_float valid_thresh = 0,
6764  static const char *MakeLossNormalizationValues[] = {
6765  "batch",
6766  "null",
6767  "valid"
6768  };
6769  return Operator("MakeLoss")
6770  .SetParam("grad_scale", grad_scale)
6771  .SetParam("valid_thresh", valid_thresh)
6772  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
6773  .SetInput("data", data)
6774  .CreateSymbol(symbol_name);
6775 }
6776 
6792 inline Symbol SVMOutput(const std::string& symbol_name,
6793  Symbol data,
6794  Symbol label,
6795  mx_float margin = 1,
6796  mx_float regularization_coefficient = 1,
6797  bool use_linear = false) {
6798  return Operator("SVMOutput")
6799  .SetParam("margin", margin)
6800  .SetParam("regularization_coefficient", regularization_coefficient)
6801  .SetParam("use_linear", use_linear)
6802  .SetInput("data", data)
6803  .SetInput("label", label)
6804  .CreateSymbol(symbol_name);
6805 }
6806 
6834 inline Symbol LRN(const std::string& symbol_name,
6835  Symbol data,
6836  uint32_t nsize,
6837  mx_float alpha = 0.0001,
6838  mx_float beta = 0.75,
6839  mx_float knorm = 2) {
6840  return Operator("LRN")
6841  .SetParam("nsize", nsize)
6842  .SetParam("alpha", alpha)
6843  .SetParam("beta", beta)
6844  .SetParam("knorm", knorm)
6845  .SetInput("data", data)
6846  .CreateSymbol(symbol_name);
6847 }
6848 
6897 inline Symbol Correlation(const std::string& symbol_name,
6898  Symbol data1,
6899  Symbol data2,
6900  uint32_t kernel_size = 1,
6901  uint32_t max_displacement = 1,
6902  uint32_t stride1 = 1,
6903  uint32_t stride2 = 1,
6904  uint32_t pad_size = 0,
6905  bool is_multiply = true) {
6906  return Operator("Correlation")
6907  .SetParam("kernel_size", kernel_size)
6908  .SetParam("max_displacement", max_displacement)
6909  .SetParam("stride1", stride1)
6910  .SetParam("stride2", stride2)
6911  .SetParam("pad_size", pad_size)
6912  .SetParam("is_multiply", is_multiply)
6913  .SetInput("data1", data1)
6914  .SetInput("data2", data2)
6915  .CreateSymbol(symbol_name);
6916 }
6917 
6996 inline Symbol SequenceMask(const std::string& symbol_name,
6997  Symbol data,
6998  Symbol sequence_length,
6999  bool use_sequence_length = false,
7000  mx_float value = 0,
7001  int axis = 0) {
7002  return Operator("SequenceMask")
7003  .SetParam("use_sequence_length", use_sequence_length)
7004  .SetParam("value", value)
7005  .SetParam("axis", axis)
7006  .SetInput("data", data)
7007  .SetInput("sequence_length", sequence_length)
7008  .CreateSymbol(symbol_name);
7009 }
7010 
7019 inline Symbol choose_element_0index(const std::string& symbol_name,
7020  Symbol lhs,
7021  Symbol rhs) {
7022  return Operator("choose_element_0index")
7023  .SetInput("lhs", lhs)
7024  .SetInput("rhs", rhs)
7025  .CreateSymbol(symbol_name);
7026 }
7027 
7037 inline Symbol fill_element_0index(const std::string& symbol_name,
7038  Symbol lhs,
7039  Symbol mhs,
7040  Symbol rhs) {
7041  return Operator("fill_element_0index")
7042  .SetInput("lhs", lhs)
7043  .SetInput("mhs", mhs)
7044  .SetInput("rhs", rhs)
7045  .CreateSymbol(symbol_name);
7046 }
7047 
7087 inline Symbol khatri_rao(const std::vector<Symbol>& args) {
7088  return Operator("khatri_rao")
7089 (args)
7090  .CreateSymbol();
7091 }
7092 
7107 inline Symbol Custom(const std::vector<Symbol>& data,
7108  const std::string& op_type) {
7109  return Operator("Custom")
7110 (data)
7111  .CreateSymbol();
7112 }
7113 
7136  Symbol rhs) {
7137  return Operator("broadcast_power")
7138  .SetInput("lhs", lhs)
7139  .SetInput("rhs", rhs)
7140  .CreateSymbol();
7141 }
7142 
7167  Symbol rhs) {
7168  return Operator("broadcast_maximum")
7169  .SetInput("lhs", lhs)
7170  .SetInput("rhs", rhs)
7171  .CreateSymbol();
7172 }
7173 
7198  Symbol rhs) {
7199  return Operator("broadcast_minimum")
7200  .SetInput("lhs", lhs)
7201  .SetInput("rhs", rhs)
7202  .CreateSymbol();
7203 }
7204 
7235  Symbol rhs) {
7236  return Operator("broadcast_hypot")
7237  .SetInput("lhs", lhs)
7238  .SetInput("rhs", rhs)
7239  .CreateSymbol();
7240 }
7241 
7315 inline Symbol Reshape(Symbol data,
7316  Shape shape = Shape(),
7317  bool reverse = false,
7318  Shape target_shape = Shape(),
7319  bool keep_highest = false) {
7320  return Operator("Reshape")
7321  .SetParam("shape", shape)
7322  .SetParam("reverse", reverse)
7323  .SetParam("target_shape", target_shape)
7324  .SetParam("keep_highest", keep_highest)
7325  .SetInput("data", data)
7326  .CreateSymbol();
7327 }
7328 
7358 inline Symbol Flatten(Symbol data) {
7359  return Operator("Flatten")
7360  .SetInput("data", data)
7361  .CreateSymbol();
7362 }
7363 
7400  Shape axes = Shape()) {
7401  return Operator("transpose")
7402  .SetParam("axes", axes)
7403  .SetInput("data", data)
7404  .CreateSymbol();
7405 }
7406 
7422  int axis) {
7423  return Operator("expand_dims")
7424  .SetParam("axis", axis)
7425  .SetInput("data", data)
7426  .CreateSymbol();
7427 }
7428 
7483 inline Symbol slice(Symbol data,
7484  Shape begin,
7485  Shape end,
7486  Shape step = Shape()) {
7487  return Operator("slice")
7488  .SetParam("begin", begin)
7489  .SetParam("end", end)
7490  .SetParam("step", step)
7491  .SetInput("data", data)
7492  .CreateSymbol();
7493 }
7494 
7527  int axis,
7528  int begin,
7529  dmlc::optional<int> end) {
7530  return Operator("slice_axis")
7531  .SetParam("axis", axis)
7532  .SetParam("begin", begin)
7533  .SetParam("end", end)
7534  .SetInput("data", data)
7535  .CreateSymbol();
7536 }
7537 
7571 inline Symbol clip(Symbol data,
7572  mx_float a_min,
7573  mx_float a_max) {
7574  return Operator("clip")
7575  .SetParam("a_min", a_min)
7576  .SetParam("a_max", a_max)
7577  .SetInput("data", data)
7578  .CreateSymbol();
7579 }
7580 
7614 inline Symbol repeat(Symbol data,
7615  int repeats,
7616  dmlc::optional<int> axis = dmlc::optional<int>()) {
7617  return Operator("repeat")
7618  .SetParam("repeats", repeats)
7619  .SetParam("axis", axis)
7620  .SetInput("data", data)
7621  .CreateSymbol();
7622 }
7623 
7668 inline Symbol tile(Symbol data,
7669  Shape reps) {
7670  return Operator("tile")
7671  .SetParam("reps", reps)
7672  .SetInput("data", data)
7673  .CreateSymbol();
7674 }
7675 
7698 inline Symbol reverse(Symbol data,
7699  Shape axis) {
7700  return Operator("reverse")
7701  .SetParam("axis", axis)
7702  .SetInput("data", data)
7703  .CreateSymbol();
7704 }
7705 
7728 inline Symbol stack(const std::vector<Symbol>& data,
7729  int num_args,
7730  int axis = 0) {
7731  return Operator("stack")
7732  .SetParam("num_args", num_args)
7733  .SetParam("axis", axis)
7734 (data)
7735  .CreateSymbol();
7736 }
7737 
7760 inline Symbol zeros_like(Symbol data) {
7761  return Operator("zeros_like")
7762  .SetInput("data", data)
7763  .CreateSymbol();
7764 }
7765 
7782 inline Symbol ones_like(Symbol data) {
7783  return Operator("ones_like")
7784  .SetInput("data", data)
7785  .CreateSymbol();
7786 }
7787 
7815  Symbol rhs) {
7816  return Operator("broadcast_add")
7817  .SetInput("lhs", lhs)
7818  .SetInput("rhs", rhs)
7819  .CreateSymbol();
7820 }
7821 
7849  Symbol rhs) {
7850  return Operator("broadcast_sub")
7851  .SetInput("lhs", lhs)
7852  .SetInput("rhs", rhs)
7853  .CreateSymbol();
7854 }
7855 
7878  Symbol rhs) {
7879  return Operator("broadcast_mul")
7880  .SetInput("lhs", lhs)
7881  .SetInput("rhs", rhs)
7882  .CreateSymbol();
7883 }
7884 
7907  Symbol rhs) {
7908  return Operator("broadcast_div")
7909  .SetInput("lhs", lhs)
7910  .SetInput("rhs", rhs)
7911  .CreateSymbol();
7912 }
7913 
7936  Symbol rhs) {
7937  return Operator("broadcast_mod")
7938  .SetInput("lhs", lhs)
7939  .SetInput("rhs", rhs)
7940  .CreateSymbol();
7941 }
7942 
7962 inline Symbol add_n(const std::vector<Symbol>& args) {
7963  return Operator("add_n")
7964 (args)
7965  .CreateSymbol();
7966 }
7967 
7998 inline Symbol argmax(Symbol data,
7999  dmlc::optional<int> axis = dmlc::optional<int>(),
8000  bool keepdims = false) {
8001  return Operator("argmax")
8002  .SetParam("axis", axis)
8003  .SetParam("keepdims", keepdims)
8004  .SetInput("data", data)
8005  .CreateSymbol();
8006 }
8007 
8038 inline Symbol argmin(Symbol data,
8039  dmlc::optional<int> axis = dmlc::optional<int>(),
8040  bool keepdims = false) {
8041  return Operator("argmin")
8042  .SetParam("axis", axis)
8043  .SetParam("keepdims", keepdims)
8044  .SetInput("data", data)
8045  .CreateSymbol();
8046 }
8047 
8070  return Operator("argmax_channel")
8071  .SetInput("data", data)
8072  .CreateSymbol();
8073 }
8074 
8119 inline Symbol pick(Symbol data,
8120  Symbol index,
8121  dmlc::optional<int> axis = dmlc::optional<int>(),
8122  bool keepdims = false) {
8123  return Operator("pick")
8124  .SetParam("axis", axis)
8125  .SetParam("keepdims", keepdims)
8126  .SetInput("data", data)
8127  .SetInput("index", index)
8128  .CreateSymbol();
8129 }
8130 
8170 inline Symbol dot(Symbol lhs,
8171  Symbol rhs,
8172  bool transpose_a = false,
8173  bool transpose_b = false) {
8174  return Operator("dot")
8175  .SetParam("transpose_a", transpose_a)
8176  .SetParam("transpose_b", transpose_b)
8177  .SetInput("lhs", lhs)
8178  .SetInput("rhs", rhs)
8179  .CreateSymbol();
8180 }
8181 
8204  Symbol rhs,
8205  bool transpose_a = false,
8206  bool transpose_b = false) {
8207  return Operator("batch_dot")
8208  .SetParam("transpose_a", transpose_a)
8209  .SetParam("transpose_b", transpose_b)
8210  .SetInput("lhs", lhs)
8211  .SetInput("rhs", rhs)
8212  .CreateSymbol();
8213 }
8214 
8232 inline Symbol relu(Symbol data) {
8233  return Operator("relu")
8234  .SetInput("data", data)
8235  .CreateSymbol();
8236 }
8237 
8252 inline Symbol sigmoid(Symbol data) {
8253  return Operator("sigmoid")
8254  .SetInput("data", data)
8255  .CreateSymbol();
8256 }
8257 
8290 inline Symbol BlockGrad(Symbol data) {
8291  return Operator("BlockGrad")
8292  .SetInput("data", data)
8293  .CreateSymbol();
8294 }
8295 
8324 inline Symbol make_loss(Symbol data) {
8325  return Operator("make_loss")
8326  .SetInput("data", data)
8327  .CreateSymbol();
8328 }
8329 
8337  Symbol rhs) {
8338  return Operator("reshape_like")
8339  .SetInput("lhs", lhs)
8340  .SetInput("rhs", rhs)
8341  .CreateSymbol();
8342 }
8343 
8362 inline Symbol Cast(Symbol data,
8363  CastDtype dtype) {
8364  static const char *CastDtypeValues[] = {
8365  "float16",
8366  "float32",
8367  "float64",
8368  "int32",
8369  "uint8"
8370  };
8371  return Operator("Cast")
8372  .SetParam("dtype", CastDtypeValues[int(dtype)])
8373  .SetInput("data", data)
8374  .CreateSymbol();
8375 }
8376 
8390 inline Symbol negative(Symbol data) {
8391  return Operator("negative")
8392  .SetInput("data", data)
8393  .CreateSymbol();
8394 }
8395 
8411 inline Symbol reciprocal(Symbol data) {
8412  return Operator("reciprocal")
8413  .SetInput("data", data)
8414  .CreateSymbol();
8415 }
8416 
8435 inline Symbol abs(Symbol data) {
8436  return Operator("abs")
8437  .SetInput("data", data)
8438  .CreateSymbol();
8439 }
8440 
8459 inline Symbol sign(Symbol data) {
8460  return Operator("sign")
8461  .SetInput("data", data)
8462  .CreateSymbol();
8463 }
8464 
8483 inline Symbol round(Symbol data) {
8484  return Operator("round")
8485  .SetInput("data", data)
8486  .CreateSymbol();
8487 }
8488 
8511 inline Symbol rint(Symbol data) {
8512  return Operator("rint")
8513  .SetInput("data", data)
8514  .CreateSymbol();
8515 }
8516 
8537 inline Symbol ceil(Symbol data) {
8538  return Operator("ceil")
8539  .SetInput("data", data)
8540  .CreateSymbol();
8541 }
8542 
8563 inline Symbol floor(Symbol data) {
8564  return Operator("floor")
8565  .SetInput("data", data)
8566  .CreateSymbol();
8567 }
8568 
8590 inline Symbol trunc(Symbol data) {
8591  return Operator("trunc")
8592  .SetInput("data", data)
8593  .CreateSymbol();
8594 }
8595 
8615 inline Symbol fix(Symbol data) {
8616  return Operator("fix")
8617  .SetInput("data", data)
8618  .CreateSymbol();
8619 }
8620 
8643 inline Symbol square(Symbol data) {
8644  return Operator("square")
8645  .SetInput("data", data)
8646  .CreateSymbol();
8647 }
8648 
8670 inline Symbol sqrt(Symbol data) {
8671  return Operator("sqrt")
8672  .SetInput("data", data)
8673  .CreateSymbol();
8674 }
8675 
8694 inline Symbol rsqrt(Symbol data) {
8695  return Operator("rsqrt")
8696  .SetInput("data", data)
8697  .CreateSymbol();
8698 }
8699 
8716 inline Symbol cbrt(Symbol data) {
8717  return Operator("cbrt")
8718  .SetInput("data", data)
8719  .CreateSymbol();
8720 }
8721 
8738 inline Symbol rcbrt(Symbol data) {
8739  return Operator("rcbrt")
8740  .SetInput("data", data)
8741  .CreateSymbol();
8742 }
8743 
8762 inline Symbol exp(Symbol data) {
8763  return Operator("exp")
8764  .SetInput("data", data)
8765  .CreateSymbol();
8766 }
8767 
8781 inline Symbol log(Symbol data) {
8782  return Operator("log")
8783  .SetInput("data", data)
8784  .CreateSymbol();
8785 }
8786 
8800 inline Symbol log10(Symbol data) {
8801  return Operator("log10")
8802  .SetInput("data", data)
8803  .CreateSymbol();
8804 }
8805 
8819 inline Symbol log2(Symbol data) {
8820  return Operator("log2")
8821  .SetInput("data", data)
8822  .CreateSymbol();
8823 }
8824 
8842 inline Symbol log1p(Symbol data) {
8843  return Operator("log1p")
8844  .SetInput("data", data)
8845  .CreateSymbol();
8846 }
8847 
8864 inline Symbol expm1(Symbol data) {
8865  return Operator("expm1")
8866  .SetInput("data", data)
8867  .CreateSymbol();
8868 }
8869 
8880 inline Symbol gamma(Symbol data) {
8881  return Operator("gamma")
8882  .SetInput("data", data)
8883  .CreateSymbol();
8884 }
8885 
8896 inline Symbol gammaln(Symbol data) {
8897  return Operator("gammaln")
8898  .SetInput("data", data)
8899  .CreateSymbol();
8900 }
8901 
8959 inline Symbol sum(Symbol data,
8960  Shape axis = Shape(),
8961  bool keepdims = false,
8962  bool exclude = false) {
8963  return Operator("sum")
8964  .SetParam("axis", axis)
8965  .SetParam("keepdims", keepdims)
8966  .SetParam("exclude", exclude)
8967  .SetInput("data", data)
8968  .CreateSymbol();
8969 }
8970 
8994 inline Symbol mean(Symbol data,
8995  Shape axis = Shape(),
8996  bool keepdims = false,
8997  bool exclude = false) {
8998  return Operator("mean")
8999  .SetParam("axis", axis)
9000  .SetParam("keepdims", keepdims)
9001  .SetParam("exclude", exclude)
9002  .SetInput("data", data)
9003  .CreateSymbol();
9004 }
9005 
9029 inline Symbol prod(Symbol data,
9030  Shape axis = Shape(),
9031  bool keepdims = false,
9032  bool exclude = false) {
9033  return Operator("prod")
9034  .SetParam("axis", axis)
9035  .SetParam("keepdims", keepdims)
9036  .SetParam("exclude", exclude)
9037  .SetInput("data", data)
9038  .CreateSymbol();
9039 }
9040 
9066 inline Symbol nansum(Symbol data,
9067  Shape axis = Shape(),
9068  bool keepdims = false,
9069  bool exclude = false) {
9070  return Operator("nansum")
9071  .SetParam("axis", axis)
9072  .SetParam("keepdims", keepdims)
9073  .SetParam("exclude", exclude)
9074  .SetInput("data", data)
9075  .CreateSymbol();
9076 }
9077 
9103 inline Symbol nanprod(Symbol data,
9104  Shape axis = Shape(),
9105  bool keepdims = false,
9106  bool exclude = false) {
9107  return Operator("nanprod")
9108  .SetParam("axis", axis)
9109  .SetParam("keepdims", keepdims)
9110  .SetParam("exclude", exclude)
9111  .SetInput("data", data)
9112  .CreateSymbol();
9113 }
9114 
9138 inline Symbol max(Symbol data,
9139  Shape axis = Shape(),
9140  bool keepdims = false,
9141  bool exclude = false) {
9142  return Operator("max")
9143  .SetParam("axis", axis)
9144  .SetParam("keepdims", keepdims)
9145  .SetParam("exclude", exclude)
9146  .SetInput("data", data)
9147  .CreateSymbol();
9148 }
9149 
9173 inline Symbol min(Symbol data,
9174  Shape axis = Shape(),
9175  bool keepdims = false,
9176  bool exclude = false) {
9177  return Operator("min")
9178  .SetParam("axis", axis)
9179  .SetParam("keepdims", keepdims)
9180  .SetParam("exclude", exclude)
9181  .SetInput("data", data)
9182  .CreateSymbol();
9183 }
9184 
9214  Shape axis = Shape(),
9215  Shape size = Shape()) {
9216  return Operator("broadcast_axis")
9217  .SetParam("axis", axis)
9218  .SetParam("size", size)
9219  .SetInput("data", data)
9220  .CreateSymbol();
9221 }
9222 
9251  Shape shape = Shape()) {
9252  return Operator("broadcast_to")
9253  .SetParam("shape", shape)
9254  .SetInput("data", data)
9255  .CreateSymbol();
9256 }
9257 
9282 inline Symbol norm(Symbol data) {
9283  return Operator("norm")
9284  .SetInput("data", data)
9285  .CreateSymbol();
9286 }
9287 
9328 inline Symbol topk(Symbol data,
9329  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9330  int k = 1,
9331  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
9332  bool is_ascend = false) {
9333  static const char *TopkRetTypValues[] = {
9334  "both",
9335  "indices",
9336  "mask",
9337  "value"
9338  };
9339  return Operator("topk")
9340  .SetParam("axis", axis)
9341  .SetParam("k", k)
9342  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
9343  .SetParam("is_ascend", is_ascend)
9344  .SetInput("data", data)
9345  .CreateSymbol();
9346 }
9347 
9379 inline Symbol sort(Symbol data,
9380  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9381  bool is_ascend = true) {
9382  return Operator("sort")
9383  .SetParam("axis", axis)
9384  .SetParam("is_ascend", is_ascend)
9385  .SetInput("data", data)
9386  .CreateSymbol();
9387 }
9388 
9418 inline Symbol argsort(Symbol data,
9419  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9420  bool is_ascend = true) {
9421  return Operator("argsort")
9422  .SetParam("axis", axis)
9423  .SetParam("is_ascend", is_ascend)
9424  .SetInput("data", data)
9425  .CreateSymbol();
9426 }
9427 
9443  Symbol rhs) {
9444  return Operator("elemwise_add")
9445  .SetInput("lhs", lhs)
9446  .SetInput("rhs", rhs)
9447  .CreateSymbol();
9448 }
9449 
9465  Symbol rhs) {
9466  return Operator("elemwise_sub")
9467  .SetInput("lhs", lhs)
9468  .SetInput("rhs", rhs)
9469  .CreateSymbol();
9470 }
9471 
9490  Symbol rhs) {
9491  return Operator("elemwise_mul")
9492  .SetInput("lhs", lhs)
9493  .SetInput("rhs", rhs)
9494  .CreateSymbol();
9495 }
9496 
9508  Symbol rhs) {
9509  return Operator("elemwise_div")
9510  .SetInput("lhs", lhs)
9511  .SetInput("rhs", rhs)
9512  .CreateSymbol();
9513 }
9514 
9566  Symbol weight,
9567  int input_dim,
9568  int output_dim,
9570  static const char *EmbeddingDtypeValues[] = {
9571  "float16",
9572  "float32",
9573  "float64",
9574  "int32",
9575  "uint8"
9576  };
9577  return Operator("Embedding")
9578  .SetParam("input_dim", input_dim)
9579  .SetParam("output_dim", output_dim)
9580  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
9581  .SetInput("data", data)
9582  .SetInput("weight", weight)
9583  .CreateSymbol();
9584 }
9585 
9624 inline Symbol take(Symbol a,
9625  Symbol indices,
9626  int axis = 0,
9627  TakeMode mode = TakeMode::kClip) {
9628  static const char *TakeModeValues[] = {
9629  "clip",
9630  "raise",
9631  "wrap"
9632  };
9633  return Operator("take")
9634  .SetParam("axis", axis)
9635  .SetParam("mode", TakeModeValues[int(mode)])
9636  .SetInput("a", a)
9637  .SetInput("indices", indices)
9638  .CreateSymbol();
9639 }
9640 
9669  Symbol indices) {
9670  return Operator("batch_take")
9671  .SetInput("a", a)
9672  .SetInput("indices", indices)
9673  .CreateSymbol();
9674 }
9675 
9719 inline Symbol one_hot(Symbol indices,
9720  int depth,
9721  double on_value = 1,
9722  double off_value = 0,
9724  static const char *One_hotDtypeValues[] = {
9725  "float16",
9726  "float32",
9727  "float64",
9728  "int32",
9729  "uint8"
9730  };
9731  return Operator("one_hot")
9732  .SetParam("depth", depth)
9733  .SetParam("on_value", on_value)
9734  .SetParam("off_value", off_value)
9735  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
9736  .SetInput("indices", indices)
9737  .CreateSymbol();
9738 }
9739 
9767  Symbol indices) {
9768  return Operator("gather_nd")
9769  .SetInput("data", data)
9770  .SetInput("indices", indices)
9771  .CreateSymbol();
9772 }
9773 
9810  Symbol indices,
9811  Shape shape) {
9812  return Operator("scatter_nd")
9813  .SetParam("shape", shape)
9814  .SetInput("data", data)
9815  .SetInput("indices", indices)
9816  .CreateSymbol();
9817 }
9818 
9841  Symbol rhs) {
9842  return Operator("broadcast_equal")
9843  .SetInput("lhs", lhs)
9844  .SetInput("rhs", rhs)
9845  .CreateSymbol();
9846 }
9847 
9870  Symbol rhs) {
9871  return Operator("broadcast_not_equal")
9872  .SetInput("lhs", lhs)
9873  .SetInput("rhs", rhs)
9874  .CreateSymbol();
9875 }
9876 
9899  Symbol rhs) {
9900  return Operator("broadcast_greater")
9901  .SetInput("lhs", lhs)
9902  .SetInput("rhs", rhs)
9903  .CreateSymbol();
9904 }
9905 
9928  Symbol rhs) {
9929  return Operator("broadcast_greater_equal")
9930  .SetInput("lhs", lhs)
9931  .SetInput("rhs", rhs)
9932  .CreateSymbol();
9933 }
9934 
9957  Symbol rhs) {
9958  return Operator("broadcast_lesser")
9959  .SetInput("lhs", lhs)
9960  .SetInput("rhs", rhs)
9961  .CreateSymbol();
9962 }
9963 
9986  Symbol rhs) {
9987  return Operator("broadcast_lesser_equal")
9988  .SetInput("lhs", lhs)
9989  .SetInput("rhs", rhs)
9990  .CreateSymbol();
9991 }
9992 
10008 inline Symbol where(Symbol condition,
10009  Symbol x,
10010  Symbol y) {
10011  return Operator("where")
10012  .SetInput("condition", condition)
10013  .SetInput("x", x)
10014  .SetInput("y", y)
10015  .CreateSymbol();
10016 }
10017 
10043  mx_float scalar) {
10044  return Operator("smooth_l1")
10045  .SetParam("scalar", scalar)
10046  .SetInput("data", data)
10047  .CreateSymbol();
10048 }
10049 
10093  Cast_storageStype stype) {
10094  static const char *Cast_storageStypeValues[] = {
10095  "csr",
10096  "default",
10097  "row_sparse"
10098  };
10099  return Operator("cast_storage")
10100  .SetParam("stype", Cast_storageStypeValues[int(stype)])
10101  .SetInput("data", data)
10102  .CreateSymbol();
10103 }
10104 
10124 inline Symbol sin(Symbol data) {
10125  return Operator("sin")
10126  .SetInput("data", data)
10127  .CreateSymbol();
10128 }
10129 
10146 inline Symbol cos(Symbol data) {
10147  return Operator("cos")
10148  .SetInput("data", data)
10149  .CreateSymbol();
10150 }
10151 
10171 inline Symbol tan(Symbol data) {
10172  return Operator("tan")
10173  .SetInput("data", data)
10174  .CreateSymbol();
10175 }
10176 
10197 inline Symbol arcsin(Symbol data) {
10198  return Operator("arcsin")
10199  .SetInput("data", data)
10200  .CreateSymbol();
10201 }
10202 
10220 inline Symbol arccos(Symbol data) {
10221  return Operator("arccos")
10222  .SetInput("data", data)
10223  .CreateSymbol();
10224 }
10225 
10245 inline Symbol arctan(Symbol data) {
10246  return Operator("arctan")
10247  .SetInput("data", data)
10248  .CreateSymbol();
10249 }
10250 
10268 inline Symbol degrees(Symbol data) {
10269  return Operator("degrees")
10270  .SetInput("data", data)
10271  .CreateSymbol();
10272 }
10273 
10291 inline Symbol radians(Symbol data) {
10292  return Operator("radians")
10293  .SetInput("data", data)
10294  .CreateSymbol();
10295 }
10296 
10314 inline Symbol sinh(Symbol data) {
10315  return Operator("sinh")
10316  .SetInput("data", data)
10317  .CreateSymbol();
10318 }
10319 
10334 inline Symbol cosh(Symbol data) {
10335  return Operator("cosh")
10336  .SetInput("data", data)
10337  .CreateSymbol();
10338 }
10339 
10357 inline Symbol tanh(Symbol data) {
10358  return Operator("tanh")
10359  .SetInput("data", data)
10360  .CreateSymbol();
10361 }
10362 
10378 inline Symbol arcsinh(Symbol data) {
10379  return Operator("arcsinh")
10380  .SetInput("data", data)
10381  .CreateSymbol();
10382 }
10383 
10396 inline Symbol arccosh(Symbol data) {
10397  return Operator("arccosh")
10398  .SetInput("data", data)
10399  .CreateSymbol();
10400 }
10401 
10417 inline Symbol arctanh(Symbol data) {
10418  return Operator("arctanh")
10419  .SetInput("data", data)
10420  .CreateSymbol();
10421 }
10422 
10451 inline Symbol softmax(Symbol data,
10452  int axis = -1) {
10453  return Operator("softmax")
10454  .SetParam("axis", axis)
10455  .SetInput("data", data)
10456  .CreateSymbol();
10457 }
10458 
10481  int axis = -1) {
10482  return Operator("log_softmax")
10483  .SetParam("axis", axis)
10484  .SetInput("data", data)
10485  .CreateSymbol();
10486 }
10487 
10502 inline Symbol UpSampling(const std::vector<Symbol>& data,
10503  uint32_t scale,
10504  UpSamplingSampleType sample_type,
10505  int num_args,
10506  uint32_t num_filter = 0,
10508  uint64_t workspace = 512) {
10509  static const char *UpSamplingSampleTypeValues[] = {
10510  "bilinear",
10511  "nearest"
10512  };
10513  static const char *UpSamplingMultiInputModeValues[] = {
10514  "concat",
10515  "sum"
10516  };
10517  return Operator("UpSampling")
10518  .SetParam("scale", scale)
10519  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
10520  .SetParam("num_args", num_args)
10521  .SetParam("num_filter", num_filter)
10522  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
10523  .SetParam("workspace", workspace)
10524 (data)
10525  .CreateSymbol();
10526 }
10527 
10591  Symbol gamma,
10592  Symbol beta,
10593  Symbol moving_mean,
10594  Symbol moving_var,
10595  double eps = 0.001,
10596  mx_float momentum = 0.9,
10597  bool fix_gamma = true,
10598  bool use_global_stats = false,
10599  bool output_mean_var = false,
10600  int axis = 1,
10601  bool cudnn_off = false) {
10602  return Operator("BatchNorm")
10603  .SetParam("eps", eps)
10604  .SetParam("momentum", momentum)
10605  .SetParam("fix_gamma", fix_gamma)
10606  .SetParam("use_global_stats", use_global_stats)
10607  .SetParam("output_mean_var", output_mean_var)
10608  .SetParam("axis", axis)
10609  .SetParam("cudnn_off", cudnn_off)
10610  .SetInput("data", data)
10611  .SetInput("gamma", gamma)
10612  .SetInput("beta", beta)
10613  .SetInput("moving_mean", moving_mean)
10614  .SetInput("moving_var", moving_var)
10615  .CreateSymbol();
10616 }
10617 
10713 inline Symbol Pad(Symbol data,
10714  PadMode mode,
10715  Shape pad_width,
10716  double constant_value = 0) {
10717  static const char *PadModeValues[] = {
10718  "constant",
10719  "edge",
10720  "reflect"
10721  };
10722  return Operator("Pad")
10723  .SetParam("mode", PadModeValues[int(mode)])
10724  .SetParam("pad_width", pad_width)
10725  .SetParam("constant_value", constant_value)
10726  .SetInput("data", data)
10727  .CreateSymbol();
10728 }
10729 
10770 inline Symbol Concat(const std::vector<Symbol>& data,
10771  int num_args,
10772  int dim = 1) {
10773  return Operator("Concat")
10774  .SetParam("num_args", num_args)
10775  .SetParam("dim", dim)
10776 (data)
10777  .CreateSymbol();
10778 }
10779 
10807  mx_float slope = 0.25,
10808  mx_float lower_bound = 0.125,
10809  mx_float upper_bound = 0.334) {
10810  static const char *LeakyReLUActTypeValues[] = {
10811  "elu",
10812  "leaky",
10813  "prelu",
10814  "rrelu"
10815  };
10816  return Operator("LeakyReLU")
10817  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
10818  .SetParam("slope", slope)
10819  .SetParam("lower_bound", lower_bound)
10820  .SetParam("upper_bound", upper_bound)
10821  .SetInput("data", data)
10822  .CreateSymbol();
10823 }
10824 
10896  int num_outputs,
10897  int axis = 1,
10898  bool squeeze_axis = false) {
10899  return Operator("SliceChannel")
10900  .SetParam("num_outputs", num_outputs)
10901  .SetParam("axis", axis)
10902  .SetParam("squeeze_axis", squeeze_axis)
10903  .SetInput("data", data)
10904  .CreateSymbol();
10905 }
10906 
10934 inline Symbol SwapAxis(Symbol data,
10935  uint32_t dim1 = 0,
10936  uint32_t dim2 = 0) {
10937  return Operator("SwapAxis")
10938  .SetParam("dim1", dim1)
10939  .SetParam("dim2", dim2)
10940  .SetInput("data", data)
10941  .CreateSymbol();
10942 }
10943 
10999  Symbol gamma,
11000  Symbol beta,
11001  mx_float eps = 0.001,
11002  mx_float momentum = 0.9,
11003  bool fix_gamma = true,
11004  bool use_global_stats = false,
11005  bool output_mean_var = false) {
11006  return Operator("BatchNorm_v1")
11007  .SetParam("eps", eps)
11008  .SetParam("momentum", momentum)
11009  .SetParam("fix_gamma", fix_gamma)
11010  .SetParam("use_global_stats", use_global_stats)
11011  .SetParam("output_mean_var", output_mean_var)
11012  .SetInput("data", data)
11013  .SetInput("gamma", gamma)
11014  .SetInput("beta", beta)
11015  .CreateSymbol();
11016 }
11017 
11055  Symbol label) {
11056  return Operator("softmax_cross_entropy")
11057  .SetInput("data", data)
11058  .SetInput("label", label)
11059  .CreateSymbol();
11060 }
11061 
11086  Symbol label,
11087  mx_float grad_scale = 1) {
11088  return Operator("LinearRegressionOutput")
11089  .SetParam("grad_scale", grad_scale)
11090  .SetInput("data", data)
11091  .SetInput("label", label)
11092  .CreateSymbol();
11093 }
11094 
11120  Symbol label,
11121  mx_float grad_scale = 1) {
11122  return Operator("MAERegressionOutput")
11123  .SetParam("grad_scale", grad_scale)
11124  .SetInput("data", data)
11125  .SetInput("label", label)
11126  .CreateSymbol();
11127 }
11128 
11154  Symbol label,
11155  mx_float grad_scale = 1) {
11156  return Operator("LogisticRegressionOutput")
11157  .SetParam("grad_scale", grad_scale)
11158  .SetInput("data", data)
11159  .SetInput("label", label)
11160  .CreateSymbol();
11161 }
11162 
11172  mx_float sparseness_target = 0.1,
11173  mx_float penalty = 0.001,
11174  mx_float momentum = 0.9) {
11175  return Operator("IdentityAttachKLSparseReg")
11176  .SetParam("sparseness_target", sparseness_target)
11177  .SetParam("penalty", penalty)
11178  .SetParam("momentum", momentum)
11179  .SetInput("data", data)
11180  .CreateSymbol();
11181 }
11182 
11210  Symbol grad,
11211  mx_float lr,
11212  mx_float wd = 0,
11213  mx_float rescale_grad = 1,
11214  mx_float clip_gradient = -1) {
11215  return Operator("signsgd_update")
11216  .SetParam("lr", lr)
11217  .SetParam("wd", wd)
11218  .SetParam("rescale_grad", rescale_grad)
11219  .SetParam("clip_gradient", clip_gradient)
11220  .SetInput("weight", weight)
11221  .SetInput("grad", grad)
11222  .CreateSymbol();
11223 }
11224 
11259  Symbol grad,
11260  Symbol mom,
11261  mx_float lr,
11262  mx_float momentum = 0,
11263  mx_float wd = 0,
11264  mx_float rescale_grad = 1,
11265  mx_float clip_gradient = -1,
11266  mx_float wd_lh = 0) {
11267  return Operator("signum_update")
11268  .SetParam("lr", lr)
11269  .SetParam("momentum", momentum)
11270  .SetParam("wd", wd)
11271  .SetParam("rescale_grad", rescale_grad)
11272  .SetParam("clip_gradient", clip_gradient)
11273  .SetParam("wd_lh", wd_lh)
11274  .SetInput("weight", weight)
11275  .SetInput("grad", grad)
11276  .SetInput("mom", mom)
11277  .CreateSymbol();
11278 }
11279 
11306 inline Symbol sgd_update(Symbol weight,
11307  Symbol grad,
11308  mx_float lr,
11309  mx_float wd = 0,
11310  mx_float rescale_grad = 1,
11311  mx_float clip_gradient = -1) {
11312  return Operator("sgd_update")
11313  .SetParam("lr", lr)
11314  .SetParam("wd", wd)
11315  .SetParam("rescale_grad", rescale_grad)
11316  .SetParam("clip_gradient", clip_gradient)
11317  .SetInput("weight", weight)
11318  .SetInput("grad", grad)
11319  .CreateSymbol();
11320 }
11321 
11367  Symbol grad,
11368  Symbol mom,
11369  mx_float lr,
11370  mx_float momentum = 0,
11371  mx_float wd = 0,
11372  mx_float rescale_grad = 1,
11373  mx_float clip_gradient = -1) {
11374  return Operator("sgd_mom_update")
11375  .SetParam("lr", lr)
11376  .SetParam("momentum", momentum)
11377  .SetParam("wd", wd)
11378  .SetParam("rescale_grad", rescale_grad)
11379  .SetParam("clip_gradient", clip_gradient)
11380  .SetInput("weight", weight)
11381  .SetInput("grad", grad)
11382  .SetInput("mom", mom)
11383  .CreateSymbol();
11384 }
11385 
11400  Symbol grad,
11401  Symbol weight32,
11402  mx_float lr,
11403  mx_float wd = 0,
11404  mx_float rescale_grad = 1,
11405  mx_float clip_gradient = -1) {
11406  return Operator("mp_sgd_update")
11407  .SetParam("lr", lr)
11408  .SetParam("wd", wd)
11409  .SetParam("rescale_grad", rescale_grad)
11410  .SetParam("clip_gradient", clip_gradient)
11411  .SetInput("weight", weight)
11412  .SetInput("grad", grad)
11413  .SetInput("weight32", weight32)
11414  .CreateSymbol();
11415 }
11416 
11433  Symbol grad,
11434  Symbol mom,
11435  Symbol weight32,
11436  mx_float lr,
11437  mx_float momentum = 0,
11438  mx_float wd = 0,
11439  mx_float rescale_grad = 1,
11440  mx_float clip_gradient = -1) {
11441  return Operator("mp_sgd_mom_update")
11442  .SetParam("lr", lr)
11443  .SetParam("momentum", momentum)
11444  .SetParam("wd", wd)
11445  .SetParam("rescale_grad", rescale_grad)
11446  .SetParam("clip_gradient", clip_gradient)
11447  .SetInput("weight", weight)
11448  .SetInput("grad", grad)
11449  .SetInput("mom", mom)
11450  .SetInput("weight32", weight32)
11451  .CreateSymbol();
11452 }
11453 
11487 inline Symbol ftml_update(Symbol weight,
11488  Symbol grad,
11489  Symbol d,
11490  Symbol v,
11491  Symbol z,
11492  mx_float lr,
11493  mx_float beta1 = 0.9,
11494  mx_float beta2 = 0.999,
11495  mx_float epsilon = 1e-08,
11496  mx_float wd = 0,
11497  mx_float rescale_grad = 1,
11498  mx_float clip_gradient = -1) {
11499  return Operator("ftml_update")
11500  .SetParam("lr", lr)
11501  .SetParam("beta1", beta1)
11502  .SetParam("beta2", beta2)
11503  .SetParam("epsilon", epsilon)
11504  .SetParam("wd", wd)
11505  .SetParam("rescale_grad", rescale_grad)
11506  .SetParam("clip_gradient", clip_gradient)
11507  .SetInput("weight", weight)
11508  .SetInput("grad", grad)
11509  .SetInput("d", d)
11510  .SetInput("v", v)
11511  .SetInput("z", z)
11512  .CreateSymbol();
11513 }
11514 
11561 inline Symbol adam_update(Symbol weight,
11562  Symbol grad,
11563  Symbol mean,
11564  Symbol var,
11565  mx_float lr,
11566  mx_float beta1 = 0.9,
11567  mx_float beta2 = 0.999,
11568  mx_float epsilon = 1e-08,
11569  mx_float wd = 0,
11570  mx_float rescale_grad = 1,
11571  mx_float clip_gradient = -1) {
11572  return Operator("adam_update")
11573  .SetParam("lr", lr)
11574  .SetParam("beta1", beta1)
11575  .SetParam("beta2", beta2)
11576  .SetParam("epsilon", epsilon)
11577  .SetParam("wd", wd)
11578  .SetParam("rescale_grad", rescale_grad)
11579  .SetParam("clip_gradient", clip_gradient)
11580  .SetInput("weight", weight)
11581  .SetInput("grad", grad)
11582  .SetInput("mean", mean)
11583  .SetInput("var", var)
11584  .CreateSymbol();
11585 }
11586 
11640  Symbol grad,
11641  Symbol n,
11642  mx_float lr,
11643  mx_float gamma1 = 0.95,
11644  mx_float epsilon = 1e-08,
11645  mx_float wd = 0,
11646  mx_float rescale_grad = 1,
11647  mx_float clip_gradient = -1,
11648  mx_float clip_weights = -1) {
11649  return Operator("rmsprop_update")
11650  .SetParam("lr", lr)
11651  .SetParam("gamma1", gamma1)
11652  .SetParam("epsilon", epsilon)
11653  .SetParam("wd", wd)
11654  .SetParam("rescale_grad", rescale_grad)
11655  .SetParam("clip_gradient", clip_gradient)
11656  .SetParam("clip_weights", clip_weights)
11657  .SetInput("weight", weight)
11658  .SetInput("grad", grad)
11659  .SetInput("n", n)
11660  .CreateSymbol();
11661 }
11662 
11708  Symbol grad,
11709  Symbol n,
11710  Symbol g,
11711  Symbol delta,
11712  mx_float lr,
11713  mx_float gamma1 = 0.95,
11714  mx_float gamma2 = 0.9,
11715  mx_float epsilon = 1e-08,
11716  mx_float wd = 0,
11717  mx_float rescale_grad = 1,
11718  mx_float clip_gradient = -1,
11719  mx_float clip_weights = -1) {
11720  return Operator("rmspropalex_update")
11721  .SetParam("lr", lr)
11722  .SetParam("gamma1", gamma1)
11723  .SetParam("gamma2", gamma2)
11724  .SetParam("epsilon", epsilon)
11725  .SetParam("wd", wd)
11726  .SetParam("rescale_grad", rescale_grad)
11727  .SetParam("clip_gradient", clip_gradient)
11728  .SetParam("clip_weights", clip_weights)
11729  .SetInput("weight", weight)
11730  .SetInput("grad", grad)
11731  .SetInput("n", n)
11732  .SetInput("g", g)
11733  .SetInput("delta", delta)
11734  .CreateSymbol();
11735 }
11736 
11775 inline Symbol ftrl_update(Symbol weight,
11776  Symbol grad,
11777  Symbol z,
11778  Symbol n,
11779  mx_float lr,
11780  mx_float lamda1 = 0.01,
11781  mx_float beta = 1,
11782  mx_float wd = 0,
11783  mx_float rescale_grad = 1,
11784  mx_float clip_gradient = -1) {
11785  return Operator("ftrl_update")
11786  .SetParam("lr", lr)
11787  .SetParam("lamda1", lamda1)
11788  .SetParam("beta", beta)
11789  .SetParam("wd", wd)
11790  .SetParam("rescale_grad", rescale_grad)
11791  .SetParam("clip_gradient", clip_gradient)
11792  .SetInput("weight", weight)
11793  .SetInput("grad", grad)
11794  .SetInput("z", z)
11795  .SetInput("n", n)
11796  .CreateSymbol();
11797 }
11798 
11851 inline Symbol Pooling(Symbol data,
11852  Shape kernel,
11853  PoolingPoolType pool_type,
11854  bool global_pool = false,
11855  bool cudnn_off = false,
11857  Shape stride = Shape(),
11858  Shape pad = Shape()) {
11859  static const char *PoolingPoolTypeValues[] = {
11860  "avg",
11861  "max",
11862  "sum"
11863  };
11864  static const char *PoolingPoolingConventionValues[] = {
11865  "full",
11866  "valid"
11867  };
11868  return Operator("Pooling")
11869  .SetParam("kernel", kernel)
11870  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
11871  .SetParam("global_pool", global_pool)
11872  .SetParam("cudnn_off", cudnn_off)
11873  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
11874  .SetParam("stride", stride)
11875  .SetParam("pad", pad)
11876  .SetInput("data", data)
11877  .CreateSymbol();
11878 }
11879 
11906  Symbol weight,
11907  Symbol bias,
11908  Shape kernel,
11909  uint32_t num_filter,
11910  Shape stride = Shape(),
11911  Shape dilate = Shape(),
11912  Shape pad = Shape(),
11913  Shape adj = Shape(),
11914  Shape target_shape = Shape(),
11915  uint32_t num_group = 1,
11916  uint64_t workspace = 512,
11917  bool no_bias = true,
11919  bool cudnn_off = false,
11921  static const char *DeconvolutionCudnnTuneValues[] = {
11922  "None",
11923  "fastest",
11924  "limited_workspace",
11925  "off"
11926  };
11927  static const char *DeconvolutionLayoutValues[] = {
11928  "None",
11929  "NCDHW",
11930  "NCHW",
11931  "NCW",
11932  "NDHWC",
11933  "NHWC"
11934  };
11935  return Operator("Deconvolution")
11936  .SetParam("kernel", kernel)
11937  .SetParam("num_filter", num_filter)
11938  .SetParam("stride", stride)
11939  .SetParam("dilate", dilate)
11940  .SetParam("pad", pad)
11941  .SetParam("adj", adj)
11942  .SetParam("target_shape", target_shape)
11943  .SetParam("num_group", num_group)
11944  .SetParam("workspace", workspace)
11945  .SetParam("no_bias", no_bias)
11946  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
11947  .SetParam("cudnn_off", cudnn_off)
11948  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
11949  .SetInput("data", data)
11950  .SetInput("weight", weight)
11951  .SetInput("bias", bias)
11952  .CreateSymbol();
11953 }
11954 
11973  ActivationActType act_type) {
11974  static const char *ActivationActTypeValues[] = {
11975  "relu",
11976  "sigmoid",
11977  "softrelu",
11978  "tanh"
11979  };
11980  return Operator("Activation")
11981  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
11982  .SetInput("data", data)
11983  .CreateSymbol();
11984 }
11985 
12080  Symbol weight,
12081  Symbol bias,
12082  Shape kernel,
12083  uint32_t num_filter,
12084  Shape stride = Shape(),
12085  Shape dilate = Shape(),
12086  Shape pad = Shape(),
12087  uint32_t num_group = 1,
12088  uint64_t workspace = 1024,
12089  bool no_bias = false,
12091  bool cudnn_off = false,
12093  static const char *ConvolutionCudnnTuneValues[] = {
12094  "None",
12095  "fastest",
12096  "limited_workspace",
12097  "off"
12098  };
12099  static const char *ConvolutionLayoutValues[] = {
12100  "None",
12101  "NCDHW",
12102  "NCHW",
12103  "NCW",
12104  "NDHWC",
12105  "NHWC"
12106  };
12107  return Operator("Convolution")
12108  .SetParam("kernel", kernel)
12109  .SetParam("num_filter", num_filter)
12110  .SetParam("stride", stride)
12111  .SetParam("dilate", dilate)
12112  .SetParam("pad", pad)
12113  .SetParam("num_group", num_group)
12114  .SetParam("workspace", workspace)
12115  .SetParam("no_bias", no_bias)
12116  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
12117  .SetParam("cudnn_off", cudnn_off)
12118  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
12119  .SetInput("data", data)
12120  .SetInput("weight", weight)
12121  .SetInput("bias", bias)
12122  .CreateSymbol();
12123 }
12124 
12163 inline Symbol Dropout(Symbol data,
12164  mx_float p = 0.5,
12166  static const char *DropoutModeValues[] = {
12167  "always",
12168  "training"
12169  };
12170  return Operator("Dropout")
12171  .SetParam("p", p)
12172  .SetParam("mode", DropoutModeValues[int(mode)])
12173  .SetInput("data", data)
12174  .CreateSymbol();
12175 }
12176 
12211  static const char *SoftmaxActivationModeValues[] = {
12212  "channel",
12213  "instance"
12214  };
12215  return Operator("SoftmaxActivation")
12216  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
12217  .SetInput("data", data)
12218  .CreateSymbol();
12219 }
12220 
12254  Symbol weight,
12255  Symbol bias,
12256  int num_hidden,
12257  bool no_bias = false,
12258  bool flatten = true) {
12259  return Operator("FullyConnected")
12260  .SetParam("num_hidden", num_hidden)
12261  .SetParam("no_bias", no_bias)
12262  .SetParam("flatten", flatten)
12263  .SetInput("data", data)
12264  .SetInput("weight", weight)
12265  .SetInput("bias", bias)
12266  .CreateSymbol();
12267 }
12268 
12319  Symbol gamma,
12320  Symbol beta,
12321  mx_float eps = 0.001) {
12322  return Operator("InstanceNorm")
12323  .SetParam("eps", eps)
12324  .SetInput("data", data)
12325  .SetInput("gamma", gamma)
12326  .SetInput("beta", beta)
12327  .CreateSymbol();
12328 }
12329 
12340  GridGeneratorTransformType transform_type,
12341  Shape target_shape = Shape(0,0)) {
12342  static const char *GridGeneratorTransformTypeValues[] = {
12343  "affine",
12344  "warp"
12345  };
12346  return Operator("GridGenerator")
12347  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
12348  .SetParam("target_shape", target_shape)
12349  .SetInput("data", data)
12350  .CreateSymbol();
12351 }
12352 
12404  Shape kernel,
12405  Pooling_v1PoolType pool_type,
12406  bool global_pool = false,
12408  Shape stride = Shape(),
12409  Shape pad = Shape()) {
12410  static const char *Pooling_v1PoolTypeValues[] = {
12411  "avg",
12412  "max",
12413  "sum"
12414  };
12415  static const char *Pooling_v1PoolingConventionValues[] = {
12416  "full",
12417  "valid"
12418  };
12419  return Operator("Pooling_v1")
12420  .SetParam("kernel", kernel)
12421  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
12422  .SetParam("global_pool", global_pool)
12423  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
12424  .SetParam("stride", stride)
12425  .SetParam("pad", pad)
12426  .SetInput("data", data)
12427  .CreateSymbol();
12428 }
12429 
12444 inline Symbol RNN(Symbol data,
12445  Symbol parameters,
12446  Symbol state,
12447  Symbol state_cell,
12448  uint32_t state_size,
12449  uint32_t num_layers,
12450  RNNMode mode,
12451  bool bidirectional = false,
12452  mx_float p = 0,
12453  bool state_outputs = false) {
12454  static const char *RNNModeValues[] = {
12455  "gru",
12456  "lstm",
12457  "rnn_relu",
12458  "rnn_tanh"
12459  };
12460  return Operator("RNN")
12461  .SetParam("state_size", state_size)
12462  .SetParam("num_layers", num_layers)
12463  .SetParam("mode", RNNModeValues[int(mode)])
12464  .SetParam("bidirectional", bidirectional)
12465  .SetParam("p", p)
12466  .SetParam("state_outputs", state_outputs)
12467  .SetInput("data", data)
12468  .SetInput("parameters", parameters)
12469  .SetInput("state", state)
12470  .SetInput("state_cell", state_cell)
12471  .CreateSymbol();
12472 }
12473 
12502  Symbol weight,
12503  Symbol bias,
12504  Shape kernel,
12505  uint32_t num_filter,
12506  Shape stride = Shape(),
12507  Shape dilate = Shape(),
12508  Shape pad = Shape(),
12509  uint32_t num_group = 1,
12510  uint64_t workspace = 1024,
12511  bool no_bias = false,
12513  bool cudnn_off = false,
12515  static const char *Convolution_v1CudnnTuneValues[] = {
12516  "None",
12517  "fastest",
12518  "limited_workspace",
12519  "off"
12520  };
12521  static const char *Convolution_v1LayoutValues[] = {
12522  "None",
12523  "NCDHW",
12524  "NCHW",
12525  "NDHWC",
12526  "NHWC"
12527  };
12528  return Operator("Convolution_v1")
12529  .SetParam("kernel", kernel)
12530  .SetParam("num_filter", num_filter)
12531  .SetParam("stride", stride)
12532  .SetParam("dilate", dilate)
12533  .SetParam("pad", pad)
12534  .SetParam("num_group", num_group)
12535  .SetParam("workspace", workspace)
12536  .SetParam("no_bias", no_bias)
12537  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
12538  .SetParam("cudnn_off", cudnn_off)
12539  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
12540  .SetInput("data", data)
12541  .SetInput("weight", weight)
12542  .SetInput("bias", bias)
12543  .CreateSymbol();
12544 }
12545 
12565 inline Symbol Crop(const std::vector<Symbol>& data,
12566  int num_args,
12567  Shape offset = Shape(0,0),
12568  Shape h_w = Shape(0,0),
12569  bool center_crop = false) {
12570  return Operator("Crop")
12571  .SetParam("num_args", num_args)
12572  .SetParam("offset", offset)
12573  .SetParam("h_w", h_w)
12574  .SetParam("center_crop", center_crop)
12575 (data)
12576  .CreateSymbol();
12577 }
12578 
12655  Symbol sequence_length,
12656  bool use_sequence_length = false,
12657  int axis = 0) {
12658  return Operator("SequenceReverse")
12659  .SetParam("use_sequence_length", use_sequence_length)
12660  .SetParam("axis", axis)
12661  .SetInput("data", data)
12662  .SetInput("sequence_length", sequence_length)
12663  .CreateSymbol();
12664 }
12665 
12676  Symbol loc,
12677  SpatialTransformerTransformType transform_type,
12678  SpatialTransformerSamplerType sampler_type,
12679  Shape target_shape = Shape(0,0)) {
12680  static const char *SpatialTransformerTransformTypeValues[] = {
12681  "affine"
12682  };
12683  static const char *SpatialTransformerSamplerTypeValues[] = {
12684  "bilinear"
12685  };
12686  return Operator("SpatialTransformer")
12687  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
12688  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
12689  .SetParam("target_shape", target_shape)
12690  .SetInput("data", data)
12691  .SetInput("loc", loc)
12692  .CreateSymbol();
12693 }
12694 
12750  Symbol sequence_length,
12751  bool use_sequence_length = false,
12752  int axis = 0) {
12753  return Operator("SequenceLast")
12754  .SetParam("use_sequence_length", use_sequence_length)
12755  .SetParam("axis", axis)
12756  .SetInput("data", data)
12757  .SetInput("sequence_length", sequence_length)
12758  .CreateSymbol();
12759 }
12760 
12856  Symbol label,
12857  mx_float grad_scale = 1,
12858  mx_float ignore_label = -1,
12859  bool multi_output = false,
12860  bool use_ignore = false,
12861  bool preserve_shape = false,
12863  bool out_grad = false,
12864  mx_float smooth_alpha = 0) {
12865  static const char *SoftmaxOutputNormalizationValues[] = {
12866  "batch",
12867  "null",
12868  "valid"
12869  };
12870  return Operator("SoftmaxOutput")
12871  .SetParam("grad_scale", grad_scale)
12872  .SetParam("ignore_label", ignore_label)
12873  .SetParam("multi_output", multi_output)
12874  .SetParam("use_ignore", use_ignore)
12875  .SetParam("preserve_shape", preserve_shape)
12876  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
12877  .SetParam("out_grad", out_grad)
12878  .SetParam("smooth_alpha", smooth_alpha)
12879  .SetInput("data", data)
12880  .SetInput("label", label)
12881  .CreateSymbol();
12882 }
12883 
12910 inline Symbol Softmax(Symbol data,
12911  mx_float grad_scale = 1,
12912  mx_float ignore_label = -1,
12913  bool multi_output = false,
12914  bool use_ignore = false,
12915  bool preserve_shape = false,
12917  bool out_grad = false,
12918  mx_float smooth_alpha = 0) {
12919  static const char *SoftmaxNormalizationValues[] = {
12920  "batch",
12921  "null",
12922  "valid"
12923  };
12924  return Operator("Softmax")
12925  .SetParam("grad_scale", grad_scale)
12926  .SetParam("ignore_label", ignore_label)
12927  .SetParam("multi_output", multi_output)
12928  .SetParam("use_ignore", use_ignore)
12929  .SetParam("preserve_shape", preserve_shape)
12930  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
12931  .SetParam("out_grad", out_grad)
12932  .SetParam("smooth_alpha", smooth_alpha)
12933  .SetInput("data", data)
12934  .CreateSymbol();
12935 }
12936 
13017  Symbol grid) {
13018  return Operator("BilinearSampler")
13019  .SetInput("data", data)
13020  .SetInput("grid", grid)
13021  .CreateSymbol();
13022 }
13023 
13080  Symbol rois,
13081  Shape pooled_size,
13082  mx_float spatial_scale) {
13083  return Operator("ROIPooling")
13084  .SetParam("pooled_size", pooled_size)
13085  .SetParam("spatial_scale", spatial_scale)
13086  .SetInput("data", data)
13087  .SetInput("rois", rois)
13088  .CreateSymbol();
13089 }
13090 
13153  mx_float eps = 1e-10,
13155  static const char *L2NormalizationModeValues[] = {
13156  "channel",
13157  "instance",
13158  "spatial"
13159  };
13160  return Operator("L2Normalization")
13161  .SetParam("eps", eps)
13162  .SetParam("mode", L2NormalizationModeValues[int(mode)])
13163  .SetInput("data", data)
13164  .CreateSymbol();
13165 }
13166 
13200 inline Symbol MakeLoss(Symbol data,
13201  mx_float grad_scale = 1,
13202  mx_float valid_thresh = 0,
13204  static const char *MakeLossNormalizationValues[] = {
13205  "batch",
13206  "null",
13207  "valid"
13208  };
13209  return Operator("MakeLoss")
13210  .SetParam("grad_scale", grad_scale)
13211  .SetParam("valid_thresh", valid_thresh)
13212  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
13213  .SetInput("data", data)
13214  .CreateSymbol();
13215 }
13216 
13232  Symbol label,
13233  mx_float margin = 1,
13234  mx_float regularization_coefficient = 1,
13235  bool use_linear = false) {
13236  return Operator("SVMOutput")
13237  .SetParam("margin", margin)
13238  .SetParam("regularization_coefficient", regularization_coefficient)
13239  .SetParam("use_linear", use_linear)
13240  .SetInput("data", data)
13241  .SetInput("label", label)
13242  .CreateSymbol();
13243 }
13244 
13271 inline Symbol LRN(Symbol data,
13272  uint32_t nsize,
13273  mx_float alpha = 0.0001,
13274  mx_float beta = 0.75,
13275  mx_float knorm = 2) {
13276  return Operator("LRN")
13277  .SetParam("nsize", nsize)
13278  .SetParam("alpha", alpha)
13279  .SetParam("beta", beta)
13280  .SetParam("knorm", knorm)
13281  .SetInput("data", data)
13282  .CreateSymbol();
13283 }
13284 
13333  Symbol data2,
13334  uint32_t kernel_size = 1,
13335  uint32_t max_displacement = 1,
13336  uint32_t stride1 = 1,
13337  uint32_t stride2 = 1,
13338  uint32_t pad_size = 0,
13339  bool is_multiply = true) {
13340  return Operator("Correlation")
13341  .SetParam("kernel_size", kernel_size)
13342  .SetParam("max_displacement", max_displacement)
13343  .SetParam("stride1", stride1)
13344  .SetParam("stride2", stride2)
13345  .SetParam("pad_size", pad_size)
13346  .SetParam("is_multiply", is_multiply)
13347  .SetInput("data1", data1)
13348  .SetInput("data2", data2)
13349  .CreateSymbol();
13350 }
13351 
13430  Symbol sequence_length,
13431  bool use_sequence_length = false,
13432  mx_float value = 0,
13433  int axis = 0) {
13434  return Operator("SequenceMask")
13435  .SetParam("use_sequence_length", use_sequence_length)
13436  .SetParam("value", value)
13437  .SetParam("axis", axis)
13438  .SetInput("data", data)
13439  .SetInput("sequence_length", sequence_length)
13440  .CreateSymbol();
13441 }
13442 
13451  Symbol rhs) {
13452  return Operator("choose_element_0index")
13453  .SetInput("lhs", lhs)
13454  .SetInput("rhs", rhs)
13455  .CreateSymbol();
13456 }
13457 
13467  Symbol mhs,
13468  Symbol rhs) {
13469  return Operator("fill_element_0index")
13470  .SetInput("lhs", lhs)
13471  .SetInput("mhs", mhs)
13472  .SetInput("rhs", rhs)
13473  .CreateSymbol();
13474 }
13475 
13476 } //namespace cpp
13477 } //namespace mxnet
13478 #endif // MXNET_CPP_OP_H_
Symbol Convolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, ConvolutionCudnnTune cudnn_tune=ConvolutionCudnnTune::kNone, bool cudnn_off=false, ConvolutionLayout layout=ConvolutionLayout::kNone)
Definition: op.h:5482
Symbol fix(const std::string &symbol_name, Symbol data)
Definition: op.h:1692
Symbol Crop(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, Shape offset=Shape(0, 0), Shape h_w=Shape(0, 0), bool center_crop=false)
Definition: op.h:6060
Symbol min(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2290
Symbol broadcast_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:894
Symbol arcsin(const std::string &symbol_name, Symbol data)
Definition: op.h:3422
Symbol FullyConnected(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, int num_hidden, bool no_bias=false, bool flatten=true)
Definition: op.h:5677
Symbol arccosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3639
Symbol arctan(const std::string &symbol_name, Symbol data)
Definition: op.h:3474
Symbol SwapAxis(const std::string &symbol_name, Symbol data, uint32_t dim1=0, uint32_t dim2=0)
Definition: op.h:4230
Symbol cast_storage(const std::string &symbol_name, Symbol data, Cast_storageStype stype)
Definition: op.h:3309
Symbol nansum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2177
Symbol add_n(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:985
Symbol log1p(const std::string &symbol_name, Symbol data)
Definition: op.h:1939
SoftmaxActivationMode
Definition: op.h:5593
Symbol mp_sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol weight32, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4717
Symbol SpatialTransformer(const std::string &symbol_name, Symbol data, Symbol loc, SpatialTransformerTransformType transform_type, SpatialTransformerSamplerType sampler_type, Shape target_shape=Shape(0, 0))
Definition: op.h:6186
Symbol slice(const std::string &symbol_name, Symbol data, Shape begin, Shape end, Shape step=Shape())
Definition: op.h:478
Symbol exp(const std::string &symbol_name, Symbol data)
Definition: op.h:1851
Symbol transpose(const std::string &symbol_name, Symbol data, Shape axes=Shape())
Definition: op.h:390
Symbol clip(const std::string &symbol_name, Symbol data, mx_float a_min, mx_float a_max)
Definition: op.h:570
Symbol elemwise_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2656
Symbol ROIPooling(const std::string &symbol_name, Symbol data, Symbol rois, Shape pooled_size, mx_float spatial_scale)
Definition: op.h:6616
Symbol broadcast_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:925
Symbol nanprod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2216
Convolution_v1Layout
Definition: op.h:5958
Symbol argmin(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1065
Symbol SequenceReverse(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:6151
Symbol mp_sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, Symbol weight32, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4752
Symbol broadcast_lesser(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3157
Symbol fill_element_0index(const std::string &symbol_name, Symbol lhs, Symbol mhs, Symbol rhs)
Definition: op.h:7037
Symbol Convolution_v1(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, Convolution_v1CudnnTune cudnn_tune=Convolution_v1CudnnTune::kNone, bool cudnn_off=false, Convolution_v1Layout layout=Convolution_v1Layout::kNone)
Definition: op.h:5994
Symbol broadcast_not_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3064
TakeMode
Definition: op.h:2752
Symbol Embedding(const std::string &symbol_name, Symbol data, Symbol weight, int input_dim, int output_dim, EmbeddingDtype dtype=EmbeddingDtype::kFloat32)
Definition: op.h:2726
Symbol SequenceLast(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:6262
Symbol ftrl_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol z, Symbol n, mx_float lr, mx_float lamda1=0.01, mx_float beta=1, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5105
Symbol reciprocal(const std::string &symbol_name, Symbol data)
Definition: op.h:1472
TopkRetTyp
Definition: op.h:2417
Symbol RNN(const std::string &symbol_name, Symbol data, Symbol parameters, Symbol state, Symbol state_cell, uint32_t state_size, uint32_t num_layers, RNNMode mode, bool bidirectional=false, mx_float p=0, bool state_outputs=false)
Definition: op.h:5908
namespace of mxnet
Definition: base.h:127
Symbol reshape_like(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1381
Pooling_v1PoolingConvention
Definition: op.h:5800
Symbol broadcast_lesser_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3188
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
Symbol InstanceNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001)
Definition: op.h:5744
Symbol sign(const std::string &symbol_name, Symbol data)
Definition: op.h:1524
GridGeneratorTransformType
Definition: op.h:5760
Cast_storageStype
Definition: op.h:3260
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
Symbol ones_like(const std::string &symbol_name, Symbol data)
Definition: op.h:793
RNNMode
Definition: op.h:5886
PadMode
Definition: op.h:3890
Symbol smooth_l1(const std::string &symbol_name, Symbol data, mx_float scalar)
Definition: op.h:3249
Symbol where(const std::string &symbol_name, Symbol condition, Symbol x, Symbol y)
Definition: op.h:3213
Symbol Dropout(const std::string &symbol_name, Symbol data, mx_float p=0.5, DropoutMode mode=DropoutMode::kTraining)
Definition: op.h:5575
Symbol expm1(const std::string &symbol_name, Symbol data)
Definition: op.h:1963
Symbol elemwise_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2585
PoolingPoolType
Definition: op.h:5132
Symbol relu(const std::string &symbol_name, Symbol data)
Definition: op.h:1269
Symbol reverse(const std::string &symbol_name, Symbol data, Shape axis)
Definition: op.h:703
Symbol rsqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1777
Symbol Pooling(const std::string &symbol_name, Symbol data, Shape kernel, PoolingPoolType pool_type, bool global_pool=false, bool cudnn_off=false, PoolingPoolingConvention pooling_convention=PoolingPoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:5198
Symbol Pooling_v1(const std::string &symbol_name, Symbol data, Shape kernel, Pooling_v1PoolType pool_type, bool global_pool=false, Pooling_v1PoolingConvention pooling_convention=Pooling_v1PoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:5856
SpatialTransformerTransformType
Definition: op.h:6166
ActivationActType
Definition: op.h:5327
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1751
Symbol Softmax(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxNormalization normalization=SoftmaxNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:6443
Symbol rint(const std::string &symbol_name, Symbol data)
Definition: op.h:1580
Symbol IdentityAttachKLSparseReg(const std::string &symbol_name, Symbol data, mx_float sparseness_target=0.1, mx_float penalty=0.001, mx_float momentum=0.9)
Definition: op.h:4479
Symbol sinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3549
Symbol scatter_nd(const std::string &symbol_name, Symbol data, Symbol indices, Shape shape)
Definition: op.h:3000
Symbol broadcast_greater_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3126
Symbol LRN(const std::string &symbol_name, Symbol data, uint32_t nsize, mx_float alpha=0.0001, mx_float beta=0.75, mx_float knorm=2)
Definition: op.h:6834
Symbol sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4682
Symbol max(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2253
Symbol arcsinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3619
Symbol sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4620
Symbol MAERegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:4423
Symbol SliceChannel(const std::string &symbol_name, Symbol data, int num_outputs, int axis=1, bool squeeze_axis=false)
Definition: op.h:4189
PoolingPoolingConvention
Definition: op.h:5140
Symbol broadcast_minimum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:180
Symbol broadcast_maximum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:147
Symbol Cast(const std::string &symbol_name, Symbol data, CastDtype dtype)
Definition: op.h:1419
DeconvolutionLayout
Definition: op.h:5239
Symbol trunc(const std::string &symbol_name, Symbol data)
Definition: op.h:1665
Pooling_v1PoolType
Definition: op.h:5792
Symbol round(const std::string &symbol_name, Symbol data)
Definition: op.h:1550
Symbol log_softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:3729
Symbol khatri_rao(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:62
Symbol cos(const std::string &symbol_name, Symbol data)
Definition: op.h:3367
Symbol L2Normalization(const std::string &symbol_name, Symbol data, mx_float eps=1e-10, L2NormalizationMode mode=L2NormalizationMode::kInstance)
Definition: op.h:6699
Symbol Correlation(const std::string &symbol_name, Symbol data1, Symbol data2, uint32_t kernel_size=1, uint32_t max_displacement=1, uint32_t stride1=1, uint32_t stride2=1, uint32_t pad_size=0, bool is_multiply=true)
Definition: op.h:6897
Symbol zeros_like(const std::string &symbol_name, Symbol data)
Definition: op.h:769
EmbeddingDtype
Definition: op.h:2667
Symbol batch_dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1238
Symbol broadcast_mod(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:956
Symbol cbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1801
Symbol prod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2138
operator helper functions
Symbol mean(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2101
Symbol tanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3596
Symbol broadcast_to(const std::string &symbol_name, Symbol data, Shape shape=Shape())
Definition: op.h:2371
Symbol elemwise_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2609
DropoutMode
Definition: op.h:5531
Symbol MakeLoss(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float valid_thresh=0, MakeLossNormalization normalization=MakeLossNormalization::kNull)
Definition: op.h:6759
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:1872
Symbol sigmoid(const std::string &symbol_name, Symbol data)
Definition: op.h:1291
CastDtype
Definition: op.h:1392
ConvolutionLayout
Definition: op.h:5379
Symbol LogisticRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:4459
Symbol gamma(const std::string &symbol_name, Symbol data)
Definition: op.h:1981
Symbol sin(const std::string &symbol_name, Symbol data)
Definition: op.h:3343
UpSamplingMultiInputMode
Definition: op.h:3748
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
Symbol elemwise_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2636
SpatialTransformerSamplerType
Definition: op.h:6172
Symbol Pad(const std::string &symbol_name, Symbol data, PadMode mode, Shape pad_width, double constant_value=0)
Definition: op.h:3992
Symbol square(const std::string &symbol_name, Symbol data)
Definition: op.h:1722
One_hotDtype
Definition: op.h:2854
UpSamplingSampleType
Definition: op.h:3740
Symbol norm(const std::string &symbol_name, Symbol data)
Definition: op.h:2405
Symbol rmsprop_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, mx_float lr, mx_float gamma1=0.95, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:4965
Symbol LeakyReLU(const std::string &symbol_name, Symbol data, LeakyReLUActType act_type=LeakyReLUActType::kLeaky, mx_float slope=0.25, mx_float lower_bound=0.125, mx_float upper_bound=0.334)
Definition: op.h:4097
Symbol make_loss(const std::string &symbol_name, Symbol data)
Definition: op.h:1367
Symbol SoftmaxActivation(const std::string &symbol_name, Symbol data, SoftmaxActivationMode mode=SoftmaxActivationMode::kInstance)
Definition: op.h:5631
Symbol broadcast_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3033
Symbol Deconvolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), Shape adj=Shape(), Shape target_shape=Shape(), uint32_t num_group=1, uint64_t workspace=512, bool no_bias=true, DeconvolutionCudnnTune cudnn_tune=DeconvolutionCudnnTune::kNone, bool cudnn_off=false, DeconvolutionLayout layout=DeconvolutionLayout::kNone)
Definition: op.h:5274
Symbol broadcast_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:827
Symbol adam_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mean, Symbol var, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4885
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:58
Symbol tan(const std::string &symbol_name, Symbol data)
Definition: op.h:3394
Convolution_v1CudnnTune
Definition: op.h:5948
Symbol repeat(const std::string &symbol_name, Symbol data, int repeats, dmlc::optional< int > axis=dmlc::optional< int >())
Definition: op.h:615
Symbol slice_axis(const std::string &symbol_name, Symbol data, int axis, int begin, dmlc::optional< int > end)
Definition: op.h:523
Symbol expand_dims(const std::string &symbol_name, Symbol data, int axis)
Definition: op.h:414
Symbol arctanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3662
Symbol softmax_cross_entropy(const std::string &symbol_name, Symbol data, Symbol label)
Definition: op.h:4354
Symbol pick(const std::string &symbol_name, Symbol data, Symbol index, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1150
Symbol broadcast_axis(const std::string &symbol_name, Symbol data, Shape axis=Shape(), Shape size=Shape())
Definition: op.h:2332
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1498
Symbol cosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3571
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2518
Symbol gather_nd(const std::string &symbol_name, Symbol data, Symbol indices)
Definition: op.h:2955
Symbol BilinearSampler(const std::string &symbol_name, Symbol data, Symbol grid)
Definition: op.h:6551
Symbol Custom(const std::string &symbol_name, const std::vector< Symbol > &data, const std::string &op_type)
Definition: op.h:84
Symbol broadcast_hypot(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:219
Symbol BatchNorm_v1(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false)
Definition: op.h:4296
Symbol UpSampling(const std::string &symbol_name, const std::vector< Symbol > &data, uint32_t scale, UpSamplingSampleType sample_type, int num_args, uint32_t num_filter=0, UpSamplingMultiInputMode multi_input_mode=UpSamplingMultiInputMode::kConcat, uint64_t workspace=512)
Definition: op.h:3768
Symbol Activation(const std::string &symbol_name, Symbol data, ActivationActType act_type)
Definition: op.h:5352
float mx_float
manually define float
Definition: c_api.h:60
Symbol SVMOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float margin=1, mx_float regularization_coefficient=1, bool use_linear=false)
Definition: op.h:6792
Symbol radians(const std::string &symbol_name, Symbol data)
Definition: op.h:3524
Symbol Concat(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int dim=1)
Definition: op.h:4051
Symbol ftml_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol d, Symbol v, Symbol z, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4809
L2NormalizationMode
Definition: op.h:6631
Symbol SequenceMask(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, mx_float value=0, int axis=0)
Definition: op.h:6996
Symbol stack(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int axis=0)
Definition: op.h:735
Symbol floor(const std::string &symbol_name, Symbol data)
Definition: op.h:1636
Symbol broadcast_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:863
Symbol take(const std::string &symbol_name, Symbol a, Symbol indices, int axis=0, TakeMode mode=TakeMode::kClip)
Definition: op.h:2797
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:1608
Symbol gammaln(const std::string &symbol_name, Symbol data)
Definition: op.h:1999
Symbol tile(const std::string &symbol_name, Symbol data, Shape reps)
Definition: op.h:671
Symbol signum_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float wd_lh=0)
Definition: op.h:4570
Symbol rmspropalex_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, Symbol g, Symbol delta, mx_float lr, mx_float gamma1=0.95, mx_float gamma2=0.9, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:5035
Symbol argsort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2559
SoftmaxNormalization
Definition: op.h:6410
Symbol softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:3698
DeconvolutionCudnnTune
Definition: op.h:5230
ConvolutionCudnnTune
Definition: op.h:5369
definition of shape
Symbol broadcast_greater(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3095
Symbol BatchNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, Symbol moving_mean, Symbol moving_var, double eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false, int axis=1, bool cudnn_off=false)
Definition: op.h:3858
Symbol rcbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1825
Symbol topk(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), int k=1, TopkRetTyp ret_typ=TopkRetTyp::kIndices, bool is_ascend=false)
Definition: op.h:2465
Symbol signsgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4519
Symbol broadcast_power(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:114
SoftmaxOutputNormalization
Definition: op.h:6277
Symbol SoftmaxOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxOutputNormalization normalization=SoftmaxOutputNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:6378
Symbol Flatten(const std::string &symbol_name, Symbol data)
Definition: op.h:347
Symbol BlockGrad(const std::string &symbol_name, Symbol data)
Definition: op.h:1331
LeakyReLUActType
Definition: op.h:4064
Symbol arccos(const std::string &symbol_name, Symbol data)
Definition: op.h:3447
Symbol argmax_channel(const std::string &symbol_name, Symbol data)
Definition: op.h:1098
Symbol batch_take(const std::string &symbol_name, Symbol a, Symbol indices)
Definition: op.h:2843
Symbol LinearRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:4387
Symbol choose_element_0index(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:7019
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Definition: op.h:302
Symbol degrees(const std::string &symbol_name, Symbol data)
Definition: op.h:3499
Symbol one_hot(const std::string &symbol_name, Symbol indices, int depth, double on_value=1, double off_value=0, One_hotDtype dtype=One_hotDtype::kFloat32)
Definition: op.h:2906
Symbol negative(const std::string &symbol_name, Symbol data)
Definition: op.h:1449
Symbol GridGenerator(const std::string &symbol_name, Symbol data, GridGeneratorTransformType transform_type, Shape target_shape=Shape(0, 0))
Definition: op.h:5775
Symbol argmax(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1023
Operator interface.
Definition: operator.h:43
Symbol interface.
Definition: symbol.h:72
Symbol sum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2064
MakeLossNormalization
Definition: op.h:6719
Symbol log10(const std::string &symbol_name, Symbol data)
Definition: op.h:1893
Symbol dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1203
Symbol log2(const std::string &symbol_name, Symbol data)
Definition: op.h:1914