mxnet
op.h
Go to the documentation of this file.
1 
8 #ifndef MXNET_CPP_OP_H_
9 #define MXNET_CPP_OP_H_
10 
11 #include <string>
12 #include <vector>
13 #include "mxnet-cpp/base.h"
14 #include "mxnet-cpp/shape.h"
15 #include "mxnet-cpp/op_util.h"
16 #include "mxnet-cpp/operator.h"
17 #include "dmlc/optional.h"
18 
19 namespace mxnet {
20 namespace cpp {
21 
44 inline Symbol broadcast_power(const std::string& symbol_name,
45  Symbol lhs,
46  Symbol rhs) {
47  return Operator("broadcast_power")
48  .SetInput("lhs", lhs)
49  .SetInput("rhs", rhs)
50  .CreateSymbol(symbol_name);
51 }
52 
77 inline Symbol broadcast_maximum(const std::string& symbol_name,
78  Symbol lhs,
79  Symbol rhs) {
80  return Operator("broadcast_maximum")
81  .SetInput("lhs", lhs)
82  .SetInput("rhs", rhs)
83  .CreateSymbol(symbol_name);
84 }
85 
110 inline Symbol broadcast_minimum(const std::string& symbol_name,
111  Symbol lhs,
112  Symbol rhs) {
113  return Operator("broadcast_minimum")
114  .SetInput("lhs", lhs)
115  .SetInput("rhs", rhs)
116  .CreateSymbol(symbol_name);
117 }
118 
149 inline Symbol broadcast_hypot(const std::string& symbol_name,
150  Symbol lhs,
151  Symbol rhs) {
152  return Operator("broadcast_hypot")
153  .SetInput("lhs", lhs)
154  .SetInput("rhs", rhs)
155  .CreateSymbol(symbol_name);
156 }
157 
232 inline Symbol Reshape(const std::string& symbol_name,
233  Symbol data,
234  Shape shape = Shape(),
235  bool reverse = false,
236  Shape target_shape = Shape(),
237  bool keep_highest = false) {
238  return Operator("Reshape")
239  .SetParam("shape", shape)
240  .SetParam("reverse", reverse)
241  .SetParam("target_shape", target_shape)
242  .SetParam("keep_highest", keep_highest)
243  .SetInput("data", data)
244  .CreateSymbol(symbol_name);
245 }
246 
277 inline Symbol Flatten(const std::string& symbol_name,
278  Symbol data) {
279  return Operator("Flatten")
280  .SetInput("data", data)
281  .CreateSymbol(symbol_name);
282 }
283 
320 inline Symbol transpose(const std::string& symbol_name,
321  Symbol data,
322  Shape axes = Shape()) {
323  return Operator("transpose")
324  .SetParam("axes", axes)
325  .SetInput("data", data)
326  .CreateSymbol(symbol_name);
327 }
328 
344 inline Symbol expand_dims(const std::string& symbol_name,
345  Symbol data,
346  int axis) {
347  return Operator("expand_dims")
348  .SetParam("axis", axis)
349  .SetInput("data", data)
350  .CreateSymbol(symbol_name);
351 }
352 
408 inline Symbol slice(const std::string& symbol_name,
409  Symbol data,
410  Shape begin,
411  Shape end,
412  Shape step = Shape()) {
413  return Operator("slice")
414  .SetParam("begin", begin)
415  .SetParam("end", end)
416  .SetParam("step", step)
417  .SetInput("data", data)
418  .CreateSymbol(symbol_name);
419 }
420 
453 inline Symbol slice_axis(const std::string& symbol_name,
454  Symbol data,
455  int axis,
456  int begin,
457  dmlc::optional<int> end) {
458  return Operator("slice_axis")
459  .SetParam("axis", axis)
460  .SetParam("begin", begin)
461  .SetParam("end", end)
462  .SetInput("data", data)
463  .CreateSymbol(symbol_name);
464 }
465 
500 inline Symbol clip(const std::string& symbol_name,
501  Symbol data,
502  mx_float a_min,
503  mx_float a_max) {
504  return Operator("clip")
505  .SetParam("a_min", a_min)
506  .SetParam("a_max", a_max)
507  .SetInput("data", data)
508  .CreateSymbol(symbol_name);
509 }
510 
545 inline Symbol repeat(const std::string& symbol_name,
546  Symbol data,
547  int repeats,
548  dmlc::optional<int> axis = dmlc::optional<int>()) {
549  return Operator("repeat")
550  .SetParam("repeats", repeats)
551  .SetParam("axis", axis)
552  .SetInput("data", data)
553  .CreateSymbol(symbol_name);
554 }
555 
601 inline Symbol tile(const std::string& symbol_name,
602  Symbol data,
603  Shape reps) {
604  return Operator("tile")
605  .SetParam("reps", reps)
606  .SetInput("data", data)
607  .CreateSymbol(symbol_name);
608 }
609 
633 inline Symbol reverse(const std::string& symbol_name,
634  Symbol data,
635  Shape axis) {
636  return Operator("reverse")
637  .SetParam("axis", axis)
638  .SetInput("data", data)
639  .CreateSymbol(symbol_name);
640 }
641 
665 inline Symbol stack(const std::string& symbol_name,
666  const std::vector<Symbol>& data,
667  int num_args,
668  int axis = 0) {
669  return Operator("stack")
670  .SetParam("num_args", num_args)
671  .SetParam("axis", axis)
672 (data)
673  .CreateSymbol(symbol_name);
674 }
675 
699 inline Symbol zeros_like(const std::string& symbol_name,
700  Symbol data) {
701  return Operator("zeros_like")
702  .SetInput("data", data)
703  .CreateSymbol(symbol_name);
704 }
705 
723 inline Symbol ones_like(const std::string& symbol_name,
724  Symbol data) {
725  return Operator("ones_like")
726  .SetInput("data", data)
727  .CreateSymbol(symbol_name);
728 }
729 
757 inline Symbol broadcast_add(const std::string& symbol_name,
758  Symbol lhs,
759  Symbol rhs) {
760  return Operator("broadcast_add")
761  .SetInput("lhs", lhs)
762  .SetInput("rhs", rhs)
763  .CreateSymbol(symbol_name);
764 }
765 
793 inline Symbol broadcast_sub(const std::string& symbol_name,
794  Symbol lhs,
795  Symbol rhs) {
796  return Operator("broadcast_sub")
797  .SetInput("lhs", lhs)
798  .SetInput("rhs", rhs)
799  .CreateSymbol(symbol_name);
800 }
801 
824 inline Symbol broadcast_mul(const std::string& symbol_name,
825  Symbol lhs,
826  Symbol rhs) {
827  return Operator("broadcast_mul")
828  .SetInput("lhs", lhs)
829  .SetInput("rhs", rhs)
830  .CreateSymbol(symbol_name);
831 }
832 
855 inline Symbol broadcast_div(const std::string& symbol_name,
856  Symbol lhs,
857  Symbol rhs) {
858  return Operator("broadcast_div")
859  .SetInput("lhs", lhs)
860  .SetInput("rhs", rhs)
861  .CreateSymbol(symbol_name);
862 }
863 
886 inline Symbol broadcast_mod(const std::string& symbol_name,
887  Symbol lhs,
888  Symbol rhs) {
889  return Operator("broadcast_mod")
890  .SetInput("lhs", lhs)
891  .SetInput("rhs", rhs)
892  .CreateSymbol(symbol_name);
893 }
894 
915 inline Symbol add_n(const std::string& symbol_name,
916  const std::vector<Symbol>& args) {
917  return Operator("add_n")
918 (args)
919  .CreateSymbol(symbol_name);
920 }
921 
953 inline Symbol argmax(const std::string& symbol_name,
954  Symbol data,
955  dmlc::optional<int> axis = dmlc::optional<int>(),
956  bool keepdims = false) {
957  return Operator("argmax")
958  .SetParam("axis", axis)
959  .SetParam("keepdims", keepdims)
960  .SetInput("data", data)
961  .CreateSymbol(symbol_name);
962 }
963 
995 inline Symbol argmin(const std::string& symbol_name,
996  Symbol data,
997  dmlc::optional<int> axis = dmlc::optional<int>(),
998  bool keepdims = false) {
999  return Operator("argmin")
1000  .SetParam("axis", axis)
1001  .SetParam("keepdims", keepdims)
1002  .SetInput("data", data)
1003  .CreateSymbol(symbol_name);
1004 }
1005 
1028 inline Symbol argmax_channel(const std::string& symbol_name,
1029  Symbol data) {
1030  return Operator("argmax_channel")
1031  .SetInput("data", data)
1032  .CreateSymbol(symbol_name);
1033 }
1034 
1080 inline Symbol pick(const std::string& symbol_name,
1081  Symbol data,
1082  Symbol index,
1083  dmlc::optional<int> axis = dmlc::optional<int>(),
1084  bool keepdims = false) {
1085  return Operator("pick")
1086  .SetParam("axis", axis)
1087  .SetParam("keepdims", keepdims)
1088  .SetInput("data", data)
1089  .SetInput("index", index)
1090  .CreateSymbol(symbol_name);
1091 }
1092 
1132 inline Symbol dot(const std::string& symbol_name,
1133  Symbol lhs,
1134  Symbol rhs,
1135  bool transpose_a = false,
1136  bool transpose_b = false) {
1137  return Operator("dot")
1138  .SetParam("transpose_a", transpose_a)
1139  .SetParam("transpose_b", transpose_b)
1140  .SetInput("lhs", lhs)
1141  .SetInput("rhs", rhs)
1142  .CreateSymbol(symbol_name);
1143 }
1144 
1167 inline Symbol batch_dot(const std::string& symbol_name,
1168  Symbol lhs,
1169  Symbol rhs,
1170  bool transpose_a = false,
1171  bool transpose_b = false) {
1172  return Operator("batch_dot")
1173  .SetParam("transpose_a", transpose_a)
1174  .SetParam("transpose_b", transpose_b)
1175  .SetInput("lhs", lhs)
1176  .SetInput("rhs", rhs)
1177  .CreateSymbol(symbol_name);
1178 }
1179 
1198 inline Symbol relu(const std::string& symbol_name,
1199  Symbol data) {
1200  return Operator("relu")
1201  .SetInput("data", data)
1202  .CreateSymbol(symbol_name);
1203 }
1204 
1220 inline Symbol sigmoid(const std::string& symbol_name,
1221  Symbol data) {
1222  return Operator("sigmoid")
1223  .SetInput("data", data)
1224  .CreateSymbol(symbol_name);
1225 }
1226 
1260 inline Symbol BlockGrad(const std::string& symbol_name,
1261  Symbol data) {
1262  return Operator("BlockGrad")
1263  .SetInput("data", data)
1264  .CreateSymbol(symbol_name);
1265 }
1266 
1296 inline Symbol make_loss(const std::string& symbol_name,
1297  Symbol data) {
1298  return Operator("make_loss")
1299  .SetInput("data", data)
1300  .CreateSymbol(symbol_name);
1301 }
1302 
1310 inline Symbol reshape_like(const std::string& symbol_name,
1311  Symbol lhs,
1312  Symbol rhs) {
1313  return Operator("reshape_like")
1314  .SetInput("lhs", lhs)
1315  .SetInput("rhs", rhs)
1316  .CreateSymbol(symbol_name);
1317 }
1318 
/*!
 * \brief Value for the "dtype" parameter of Cast(). Each enumerator's numeric
 *        value is used as an index into Cast()'s string table, so the order
 *        here must stay in sync with that table.
 */
enum class CastDtype : int {
  kFloat16 = 0,  // "float16"
  kFloat32 = 1,  // "float32"
  kFloat64 = 2,  // "float64"
  kInt32 = 3,    // "int32"
  kUint8 = 4     // "uint8"
};
1328 
1348 inline Symbol Cast(const std::string& symbol_name,
1349  Symbol data,
1350  CastDtype dtype) {
1351  static const char *CastDtypeValues[] = {
1352  "float16",
1353  "float32",
1354  "float64",
1355  "int32",
1356  "uint8"
1357  };
1358  return Operator("Cast")
1359  .SetParam("dtype", CastDtypeValues[int(dtype)])
1360  .SetInput("data", data)
1361  .CreateSymbol(symbol_name);
1362 }
1363 
1378 inline Symbol negative(const std::string& symbol_name,
1379  Symbol data) {
1380  return Operator("negative")
1381  .SetInput("data", data)
1382  .CreateSymbol(symbol_name);
1383 }
1384 
1401 inline Symbol reciprocal(const std::string& symbol_name,
1402  Symbol data) {
1403  return Operator("reciprocal")
1404  .SetInput("data", data)
1405  .CreateSymbol(symbol_name);
1406 }
1407 
1427 inline Symbol abs(const std::string& symbol_name,
1428  Symbol data) {
1429  return Operator("abs")
1430  .SetInput("data", data)
1431  .CreateSymbol(symbol_name);
1432 }
1433 
1453 inline Symbol sign(const std::string& symbol_name,
1454  Symbol data) {
1455  return Operator("sign")
1456  .SetInput("data", data)
1457  .CreateSymbol(symbol_name);
1458 }
1459 
1479 inline Symbol round(const std::string& symbol_name,
1480  Symbol data) {
1481  return Operator("round")
1482  .SetInput("data", data)
1483  .CreateSymbol(symbol_name);
1484 }
1485 
1509 inline Symbol rint(const std::string& symbol_name,
1510  Symbol data) {
1511  return Operator("rint")
1512  .SetInput("data", data)
1513  .CreateSymbol(symbol_name);
1514 }
1515 
1537 inline Symbol ceil(const std::string& symbol_name,
1538  Symbol data) {
1539  return Operator("ceil")
1540  .SetInput("data", data)
1541  .CreateSymbol(symbol_name);
1542 }
1543 
1565 inline Symbol floor(const std::string& symbol_name,
1566  Symbol data) {
1567  return Operator("floor")
1568  .SetInput("data", data)
1569  .CreateSymbol(symbol_name);
1570 }
1571 
1594 inline Symbol trunc(const std::string& symbol_name,
1595  Symbol data) {
1596  return Operator("trunc")
1597  .SetInput("data", data)
1598  .CreateSymbol(symbol_name);
1599 }
1600 
1621 inline Symbol fix(const std::string& symbol_name,
1622  Symbol data) {
1623  return Operator("fix")
1624  .SetInput("data", data)
1625  .CreateSymbol(symbol_name);
1626 }
1627 
1651 inline Symbol square(const std::string& symbol_name,
1652  Symbol data) {
1653  return Operator("square")
1654  .SetInput("data", data)
1655  .CreateSymbol(symbol_name);
1656 }
1657 
1680 inline Symbol sqrt(const std::string& symbol_name,
1681  Symbol data) {
1682  return Operator("sqrt")
1683  .SetInput("data", data)
1684  .CreateSymbol(symbol_name);
1685 }
1686 
1706 inline Symbol rsqrt(const std::string& symbol_name,
1707  Symbol data) {
1708  return Operator("rsqrt")
1709  .SetInput("data", data)
1710  .CreateSymbol(symbol_name);
1711 }
1712 
1730 inline Symbol cbrt(const std::string& symbol_name,
1731  Symbol data) {
1732  return Operator("cbrt")
1733  .SetInput("data", data)
1734  .CreateSymbol(symbol_name);
1735 }
1736 
1754 inline Symbol rcbrt(const std::string& symbol_name,
1755  Symbol data) {
1756  return Operator("rcbrt")
1757  .SetInput("data", data)
1758  .CreateSymbol(symbol_name);
1759 }
1760 
1780 inline Symbol exp(const std::string& symbol_name,
1781  Symbol data) {
1782  return Operator("exp")
1783  .SetInput("data", data)
1784  .CreateSymbol(symbol_name);
1785 }
1786 
1801 inline Symbol log(const std::string& symbol_name,
1802  Symbol data) {
1803  return Operator("log")
1804  .SetInput("data", data)
1805  .CreateSymbol(symbol_name);
1806 }
1807 
1822 inline Symbol log10(const std::string& symbol_name,
1823  Symbol data) {
1824  return Operator("log10")
1825  .SetInput("data", data)
1826  .CreateSymbol(symbol_name);
1827 }
1828 
1843 inline Symbol log2(const std::string& symbol_name,
1844  Symbol data) {
1845  return Operator("log2")
1846  .SetInput("data", data)
1847  .CreateSymbol(symbol_name);
1848 }
1849 
1868 inline Symbol log1p(const std::string& symbol_name,
1869  Symbol data) {
1870  return Operator("log1p")
1871  .SetInput("data", data)
1872  .CreateSymbol(symbol_name);
1873 }
1874 
1892 inline Symbol expm1(const std::string& symbol_name,
1893  Symbol data) {
1894  return Operator("expm1")
1895  .SetInput("data", data)
1896  .CreateSymbol(symbol_name);
1897 }
1898 
1910 inline Symbol gamma(const std::string& symbol_name,
1911  Symbol data) {
1912  return Operator("gamma")
1913  .SetInput("data", data)
1914  .CreateSymbol(symbol_name);
1915 }
1916 
1928 inline Symbol gammaln(const std::string& symbol_name,
1929  Symbol data) {
1930  return Operator("gammaln")
1931  .SetInput("data", data)
1932  .CreateSymbol(symbol_name);
1933 }
1934 
1993 inline Symbol sum(const std::string& symbol_name,
1994  Symbol data,
1995  Shape axis = Shape(),
1996  bool keepdims = false,
1997  bool exclude = false) {
1998  return Operator("sum")
1999  .SetParam("axis", axis)
2000  .SetParam("keepdims", keepdims)
2001  .SetParam("exclude", exclude)
2002  .SetInput("data", data)
2003  .CreateSymbol(symbol_name);
2004 }
2005 
2030 inline Symbol mean(const std::string& symbol_name,
2031  Symbol data,
2032  Shape axis = Shape(),
2033  bool keepdims = false,
2034  bool exclude = false) {
2035  return Operator("mean")
2036  .SetParam("axis", axis)
2037  .SetParam("keepdims", keepdims)
2038  .SetParam("exclude", exclude)
2039  .SetInput("data", data)
2040  .CreateSymbol(symbol_name);
2041 }
2042 
2067 inline Symbol prod(const std::string& symbol_name,
2068  Symbol data,
2069  Shape axis = Shape(),
2070  bool keepdims = false,
2071  bool exclude = false) {
2072  return Operator("prod")
2073  .SetParam("axis", axis)
2074  .SetParam("keepdims", keepdims)
2075  .SetParam("exclude", exclude)
2076  .SetInput("data", data)
2077  .CreateSymbol(symbol_name);
2078 }
2079 
2106 inline Symbol nansum(const std::string& symbol_name,
2107  Symbol data,
2108  Shape axis = Shape(),
2109  bool keepdims = false,
2110  bool exclude = false) {
2111  return Operator("nansum")
2112  .SetParam("axis", axis)
2113  .SetParam("keepdims", keepdims)
2114  .SetParam("exclude", exclude)
2115  .SetInput("data", data)
2116  .CreateSymbol(symbol_name);
2117 }
2118 
2145 inline Symbol nanprod(const std::string& symbol_name,
2146  Symbol data,
2147  Shape axis = Shape(),
2148  bool keepdims = false,
2149  bool exclude = false) {
2150  return Operator("nanprod")
2151  .SetParam("axis", axis)
2152  .SetParam("keepdims", keepdims)
2153  .SetParam("exclude", exclude)
2154  .SetInput("data", data)
2155  .CreateSymbol(symbol_name);
2156 }
2157 
2182 inline Symbol max(const std::string& symbol_name,
2183  Symbol data,
2184  Shape axis = Shape(),
2185  bool keepdims = false,
2186  bool exclude = false) {
2187  return Operator("max")
2188  .SetParam("axis", axis)
2189  .SetParam("keepdims", keepdims)
2190  .SetParam("exclude", exclude)
2191  .SetInput("data", data)
2192  .CreateSymbol(symbol_name);
2193 }
2194 
2219 inline Symbol min(const std::string& symbol_name,
2220  Symbol data,
2221  Shape axis = Shape(),
2222  bool keepdims = false,
2223  bool exclude = false) {
2224  return Operator("min")
2225  .SetParam("axis", axis)
2226  .SetParam("keepdims", keepdims)
2227  .SetParam("exclude", exclude)
2228  .SetInput("data", data)
2229  .CreateSymbol(symbol_name);
2230 }
2231 
2261 inline Symbol broadcast_axis(const std::string& symbol_name,
2262  Symbol data,
2263  Shape axis = Shape(),
2264  Shape size = Shape()) {
2265  return Operator("broadcast_axis")
2266  .SetParam("axis", axis)
2267  .SetParam("size", size)
2268  .SetInput("data", data)
2269  .CreateSymbol(symbol_name);
2270 }
2271 
2300 inline Symbol broadcast_to(const std::string& symbol_name,
2301  Symbol data,
2302  Shape shape = Shape()) {
2303  return Operator("broadcast_to")
2304  .SetParam("shape", shape)
2305  .SetInput("data", data)
2306  .CreateSymbol(symbol_name);
2307 }
2308 
2326 inline Symbol norm(const std::string& symbol_name,
2327  Symbol data) {
2328  return Operator("norm")
2329  .SetInput("data", data)
2330  .CreateSymbol(symbol_name);
2331 }
2332 
/*!
 * \brief Value for the "ret_typ" parameter of topk(). Each enumerator's numeric
 *        value is used as an index into topk()'s string table, so the order
 *        here must stay in sync with that table.
 */
enum class TopkRetTyp : int {
  kBoth = 0,     // "both"
  kIndices = 1,  // "indices"
  kMask = 2,     // "mask"
  kValue = 3     // "value"
};
2344 
2386 inline Symbol topk(const std::string& symbol_name,
2387  Symbol data,
2388  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2389  int k = 1,
2390  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
2391  bool is_ascend = false) {
2392  static const char *TopkRetTypValues[] = {
2393  "both",
2394  "indices",
2395  "mask",
2396  "value"
2397  };
2398  return Operator("topk")
2399  .SetParam("axis", axis)
2400  .SetParam("k", k)
2401  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
2402  .SetParam("is_ascend", is_ascend)
2403  .SetInput("data", data)
2404  .CreateSymbol(symbol_name);
2405 }
2406 
2439 inline Symbol sort(const std::string& symbol_name,
2440  Symbol data,
2441  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2442  bool is_ascend = true) {
2443  return Operator("sort")
2444  .SetParam("axis", axis)
2445  .SetParam("is_ascend", is_ascend)
2446  .SetInput("data", data)
2447  .CreateSymbol(symbol_name);
2448 }
2449 
2480 inline Symbol argsort(const std::string& symbol_name,
2481  Symbol data,
2482  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2483  bool is_ascend = true) {
2484  return Operator("argsort")
2485  .SetParam("axis", axis)
2486  .SetParam("is_ascend", is_ascend)
2487  .SetInput("data", data)
2488  .CreateSymbol(symbol_name);
2489 }
2490 
2506 inline Symbol elemwise_add(const std::string& symbol_name,
2507  Symbol lhs,
2508  Symbol rhs) {
2509  return Operator("elemwise_add")
2510  .SetInput("lhs", lhs)
2511  .SetInput("rhs", rhs)
2512  .CreateSymbol(symbol_name);
2513 }
2514 
2530 inline Symbol elemwise_sub(const std::string& symbol_name,
2531  Symbol lhs,
2532  Symbol rhs) {
2533  return Operator("elemwise_sub")
2534  .SetInput("lhs", lhs)
2535  .SetInput("rhs", rhs)
2536  .CreateSymbol(symbol_name);
2537 }
2538 
2557 inline Symbol elemwise_mul(const std::string& symbol_name,
2558  Symbol lhs,
2559  Symbol rhs) {
2560  return Operator("elemwise_mul")
2561  .SetInput("lhs", lhs)
2562  .SetInput("rhs", rhs)
2563  .CreateSymbol(symbol_name);
2564 }
2565 
2577 inline Symbol elemwise_div(const std::string& symbol_name,
2578  Symbol lhs,
2579  Symbol rhs) {
2580  return Operator("elemwise_div")
2581  .SetInput("lhs", lhs)
2582  .SetInput("rhs", rhs)
2583  .CreateSymbol(symbol_name);
2584 }
2585 
/*!
 * \brief Value for the "dtype" parameter of Embedding(). Each enumerator's
 *        numeric value is used as an index into Embedding()'s string table,
 *        so the order here must stay in sync with that table.
 */
enum class EmbeddingDtype : int {
  kFloat16 = 0,  // "float16"
  kFloat32 = 1,  // "float32"
  kFloat64 = 2,  // "float64"
  kInt32 = 3,    // "int32"
  kUint8 = 4     // "uint8"
};
2595 
2647 inline Symbol Embedding(const std::string& symbol_name,
2648  Symbol data,
2649  Symbol weight,
2650  int input_dim,
2651  int output_dim,
2653  static const char *EmbeddingDtypeValues[] = {
2654  "float16",
2655  "float32",
2656  "float64",
2657  "int32",
2658  "uint8"
2659  };
2660  return Operator("Embedding")
2661  .SetParam("input_dim", input_dim)
2662  .SetParam("output_dim", output_dim)
2663  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
2664  .SetInput("data", data)
2665  .SetInput("weight", weight)
2666  .CreateSymbol(symbol_name);
2667 }
2668 
/*!
 * \brief Value for the "mode" parameter of take(). Each enumerator's numeric
 *        value is used as an index into take()'s string table, so the order
 *        here must stay in sync with that table.
 */
enum class TakeMode : int {
  kClip = 0,   // "clip"
  kRaise = 1,  // "raise"
  kWrap = 2    // "wrap"
};
2678 
2718 inline Symbol take(const std::string& symbol_name,
2719  Symbol a,
2720  Symbol indices,
2721  int axis = 0,
2722  TakeMode mode = TakeMode::kClip) {
2723  static const char *TakeModeValues[] = {
2724  "clip",
2725  "raise",
2726  "wrap"
2727  };
2728  return Operator("take")
2729  .SetParam("axis", axis)
2730  .SetParam("mode", TakeModeValues[int(mode)])
2731  .SetInput("a", a)
2732  .SetInput("indices", indices)
2733  .CreateSymbol(symbol_name);
2734 }
2735 
2764 inline Symbol batch_take(const std::string& symbol_name,
2765  Symbol a,
2766  Symbol indices) {
2767  return Operator("batch_take")
2768  .SetInput("a", a)
2769  .SetInput("indices", indices)
2770  .CreateSymbol(symbol_name);
2771 }
2772 
/*!
 * \brief Value for the "dtype" parameter of one_hot(). Each enumerator's
 *        numeric value is used as an index into one_hot()'s string table,
 *        so the order here must stay in sync with that table.
 */
enum class One_hotDtype : int {
  kFloat16 = 0,  // "float16"
  kFloat32 = 1,  // "float32"
  kFloat64 = 2,  // "float64"
  kInt32 = 3,    // "int32"
  kUint8 = 4     // "uint8"
};
2782 
2827 inline Symbol one_hot(const std::string& symbol_name,
2828  Symbol indices,
2829  int depth,
2830  double on_value = 1,
2831  double off_value = 0,
2833  static const char *One_hotDtypeValues[] = {
2834  "float16",
2835  "float32",
2836  "float64",
2837  "int32",
2838  "uint8"
2839  };
2840  return Operator("one_hot")
2841  .SetParam("depth", depth)
2842  .SetParam("on_value", on_value)
2843  .SetParam("off_value", off_value)
2844  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
2845  .SetInput("indices", indices)
2846  .CreateSymbol(symbol_name);
2847 }
2848 
2876 inline Symbol gather_nd(const std::string& symbol_name,
2877  Symbol data,
2878  Symbol indices) {
2879  return Operator("gather_nd")
2880  .SetInput("data", data)
2881  .SetInput("indices", indices)
2882  .CreateSymbol(symbol_name);
2883 }
2884 
2921 inline Symbol scatter_nd(const std::string& symbol_name,
2922  Symbol data,
2923  Symbol indices,
2924  Shape shape) {
2925  return Operator("scatter_nd")
2926  .SetParam("shape", shape)
2927  .SetInput("data", data)
2928  .SetInput("indices", indices)
2929  .CreateSymbol(symbol_name);
2930 }
2931 
2954 inline Symbol broadcast_equal(const std::string& symbol_name,
2955  Symbol lhs,
2956  Symbol rhs) {
2957  return Operator("broadcast_equal")
2958  .SetInput("lhs", lhs)
2959  .SetInput("rhs", rhs)
2960  .CreateSymbol(symbol_name);
2961 }
2962 
2985 inline Symbol broadcast_not_equal(const std::string& symbol_name,
2986  Symbol lhs,
2987  Symbol rhs) {
2988  return Operator("broadcast_not_equal")
2989  .SetInput("lhs", lhs)
2990  .SetInput("rhs", rhs)
2991  .CreateSymbol(symbol_name);
2992 }
2993 
3016 inline Symbol broadcast_greater(const std::string& symbol_name,
3017  Symbol lhs,
3018  Symbol rhs) {
3019  return Operator("broadcast_greater")
3020  .SetInput("lhs", lhs)
3021  .SetInput("rhs", rhs)
3022  .CreateSymbol(symbol_name);
3023 }
3024 
3047 inline Symbol broadcast_greater_equal(const std::string& symbol_name,
3048  Symbol lhs,
3049  Symbol rhs) {
3050  return Operator("broadcast_greater_equal")
3051  .SetInput("lhs", lhs)
3052  .SetInput("rhs", rhs)
3053  .CreateSymbol(symbol_name);
3054 }
3055 
3078 inline Symbol broadcast_lesser(const std::string& symbol_name,
3079  Symbol lhs,
3080  Symbol rhs) {
3081  return Operator("broadcast_lesser")
3082  .SetInput("lhs", lhs)
3083  .SetInput("rhs", rhs)
3084  .CreateSymbol(symbol_name);
3085 }
3086 
3109 inline Symbol broadcast_lesser_equal(const std::string& symbol_name,
3110  Symbol lhs,
3111  Symbol rhs) {
3112  return Operator("broadcast_lesser_equal")
3113  .SetInput("lhs", lhs)
3114  .SetInput("rhs", rhs)
3115  .CreateSymbol(symbol_name);
3116 }
3117 
3134 inline Symbol where(const std::string& symbol_name,
3135  Symbol condition,
3136  Symbol x,
3137  Symbol y) {
3138  return Operator("where")
3139  .SetInput("condition", condition)
3140  .SetInput("x", x)
3141  .SetInput("y", y)
3142  .CreateSymbol(symbol_name);
3143 }
3144 
3170 inline Symbol smooth_l1(const std::string& symbol_name,
3171  Symbol data,
3172  mx_float scalar) {
3173  return Operator("smooth_l1")
3174  .SetParam("scalar", scalar)
3175  .SetInput("data", data)
3176  .CreateSymbol(symbol_name);
3177 }
3178 
/*!
 * \brief Value for the "stype" parameter of cast_storage(). Each enumerator's
 *        numeric value is used as an index into cast_storage()'s string table,
 *        so the order here must stay in sync with that table.
 */
enum class Cast_storageStype : int {
  kCsr = 0,        // "csr"
  kDefault = 1,    // "default"
  kRow_sparse = 2  // "row_sparse"
};
3186 
3230 inline Symbol cast_storage(const std::string& symbol_name,
3231  Symbol data,
3232  Cast_storageStype stype) {
3233  static const char *Cast_storageStypeValues[] = {
3234  "csr",
3235  "default",
3236  "row_sparse"
3237  };
3238  return Operator("cast_storage")
3239  .SetParam("stype", Cast_storageStypeValues[int(stype)])
3240  .SetInput("data", data)
3241  .CreateSymbol(symbol_name);
3242 }
3243 
3264 inline Symbol sin(const std::string& symbol_name,
3265  Symbol data) {
3266  return Operator("sin")
3267  .SetInput("data", data)
3268  .CreateSymbol(symbol_name);
3269 }
3270 
3288 inline Symbol cos(const std::string& symbol_name,
3289  Symbol data) {
3290  return Operator("cos")
3291  .SetInput("data", data)
3292  .CreateSymbol(symbol_name);
3293 }
3294 
3315 inline Symbol tan(const std::string& symbol_name,
3316  Symbol data) {
3317  return Operator("tan")
3318  .SetInput("data", data)
3319  .CreateSymbol(symbol_name);
3320 }
3321 
3343 inline Symbol arcsin(const std::string& symbol_name,
3344  Symbol data) {
3345  return Operator("arcsin")
3346  .SetInput("data", data)
3347  .CreateSymbol(symbol_name);
3348 }
3349 
3368 inline Symbol arccos(const std::string& symbol_name,
3369  Symbol data) {
3370  return Operator("arccos")
3371  .SetInput("data", data)
3372  .CreateSymbol(symbol_name);
3373 }
3374 
3395 inline Symbol arctan(const std::string& symbol_name,
3396  Symbol data) {
3397  return Operator("arctan")
3398  .SetInput("data", data)
3399  .CreateSymbol(symbol_name);
3400 }
3401 
3420 inline Symbol degrees(const std::string& symbol_name,
3421  Symbol data) {
3422  return Operator("degrees")
3423  .SetInput("data", data)
3424  .CreateSymbol(symbol_name);
3425 }
3426 
3445 inline Symbol radians(const std::string& symbol_name,
3446  Symbol data) {
3447  return Operator("radians")
3448  .SetInput("data", data)
3449  .CreateSymbol(symbol_name);
3450 }
3451 
3470 inline Symbol sinh(const std::string& symbol_name,
3471  Symbol data) {
3472  return Operator("sinh")
3473  .SetInput("data", data)
3474  .CreateSymbol(symbol_name);
3475 }
3476 
3492 inline Symbol cosh(const std::string& symbol_name,
3493  Symbol data) {
3494  return Operator("cosh")
3495  .SetInput("data", data)
3496  .CreateSymbol(symbol_name);
3497 }
3498 
3517 inline Symbol tanh(const std::string& symbol_name,
3518  Symbol data) {
3519  return Operator("tanh")
3520  .SetInput("data", data)
3521  .CreateSymbol(symbol_name);
3522 }
3523 
3540 inline Symbol arcsinh(const std::string& symbol_name,
3541  Symbol data) {
3542  return Operator("arcsinh")
3543  .SetInput("data", data)
3544  .CreateSymbol(symbol_name);
3545 }
3546 
3560 inline Symbol arccosh(const std::string& symbol_name,
3561  Symbol data) {
3562  return Operator("arccosh")
3563  .SetInput("data", data)
3564  .CreateSymbol(symbol_name);
3565 }
3566 
3583 inline Symbol arctanh(const std::string& symbol_name,
3584  Symbol data) {
3585  return Operator("arctanh")
3586  .SetInput("data", data)
3587  .CreateSymbol(symbol_name);
3588 }
3589 
3605 inline Symbol Custom(const std::string& symbol_name,
3606  const std::vector<Symbol>& data,
3607  const std::string& op_type) {
3608  return Operator("Custom")
3609 (data)
3610  .CreateSymbol(symbol_name);
3611 }
3612 
3642 inline Symbol softmax(const std::string& symbol_name,
3643  Symbol data,
3644  int axis = -1) {
3645  return Operator("softmax")
3646  .SetParam("axis", axis)
3647  .SetInput("data", data)
3648  .CreateSymbol(symbol_name);
3649 }
3650 
3673 inline Symbol log_softmax(const std::string& symbol_name,
3674  Symbol data,
3675  int axis = -1) {
3676  return Operator("log_softmax")
3677  .SetParam("axis", axis)
3678  .SetInput("data", data)
3679  .CreateSymbol(symbol_name);
3680 }
3681 
/*!
 * \brief Value for the "act_type" parameter of LeakyReLU(). Each enumerator's
 *        numeric value is used as an index into LeakyReLU()'s string table,
 *        so the order here must stay in sync with that table.
 */
enum class LeakyReLUActType : int {
  kElu = 0,    // "elu"
  kLeaky = 1,  // "leaky"
  kPrelu = 2,  // "prelu"
  kRrelu = 3   // "rrelu"
};
3690 
3717 inline Symbol LeakyReLU(const std::string& symbol_name,
3718  Symbol data,
3720  mx_float slope = 0.25,
3721  mx_float lower_bound = 0.125,
3722  mx_float upper_bound = 0.334) {
3723  static const char *LeakyReLUActTypeValues[] = {
3724  "elu",
3725  "leaky",
3726  "prelu",
3727  "rrelu"
3728  };
3729  return Operator("LeakyReLU")
3730  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
3731  .SetParam("slope", slope)
3732  .SetParam("lower_bound", lower_bound)
3733  .SetParam("upper_bound", upper_bound)
3734  .SetInput("data", data)
3735  .CreateSymbol(symbol_name);
3736 }
3737 
3766 inline Symbol SwapAxis(const std::string& symbol_name,
3767  Symbol data,
3768  uint32_t dim1 = 0,
3769  uint32_t dim2 = 0) {
3770  return Operator("SwapAxis")
3771  .SetParam("dim1", dim1)
3772  .SetParam("dim2", dim2)
3773  .SetInput("data", data)
3774  .CreateSymbol(symbol_name);
3775 }
3776 
3832 inline Symbol BatchNorm_v1(const std::string& symbol_name,
3833  Symbol data,
3834  Symbol gamma,
3835  Symbol beta,
3836  mx_float eps = 0.001,
3837  mx_float momentum = 0.9,
3838  bool fix_gamma = true,
3839  bool use_global_stats = false,
3840  bool output_mean_var = false) {
3841  return Operator("BatchNorm_v1")
3842  .SetParam("eps", eps)
3843  .SetParam("momentum", momentum)
3844  .SetParam("fix_gamma", fix_gamma)
3845  .SetParam("use_global_stats", use_global_stats)
3846  .SetParam("output_mean_var", output_mean_var)
3847  .SetInput("data", data)
3848  .SetInput("gamma", gamma)
3849  .SetInput("beta", beta)
3850  .CreateSymbol(symbol_name);
3851 }
3852 
3894 inline Symbol Concat(const std::string& symbol_name,
3895  const std::vector<Symbol>& data,
3896  int num_args,
3897  int dim = 1) {
3898  return Operator("Concat")
3899  .SetParam("num_args", num_args)
3900  .SetParam("dim", dim)
3901 (data)
3902  .CreateSymbol(symbol_name);
3903 }
3904 
3932 inline Symbol sgd_update(const std::string& symbol_name,
3933  Symbol weight,
3934  Symbol grad,
3935  mx_float lr,
3936  mx_float wd = 0,
3937  mx_float rescale_grad = 1,
3938  mx_float clip_gradient = -1) {
3939  return Operator("sgd_update")
3940  .SetParam("lr", lr)
3941  .SetParam("wd", wd)
3942  .SetParam("rescale_grad", rescale_grad)
3943  .SetParam("clip_gradient", clip_gradient)
3944  .SetInput("weight", weight)
3945  .SetInput("grad", grad)
3946  .CreateSymbol(symbol_name);
3947 }
3948 
3991 inline Symbol sgd_mom_update(const std::string& symbol_name,
3992  Symbol weight,
3993  Symbol grad,
3994  Symbol mom,
3995  mx_float lr,
3996  mx_float momentum = 0,
3997  mx_float wd = 0,
3998  mx_float rescale_grad = 1,
3999  mx_float clip_gradient = -1) {
4000  return Operator("sgd_mom_update")
4001  .SetParam("lr", lr)
4002  .SetParam("momentum", momentum)
4003  .SetParam("wd", wd)
4004  .SetParam("rescale_grad", rescale_grad)
4005  .SetParam("clip_gradient", clip_gradient)
4006  .SetInput("weight", weight)
4007  .SetInput("grad", grad)
4008  .SetInput("mom", mom)
4009  .CreateSymbol(symbol_name);
4010 }
4011 
4026 inline Symbol mp_sgd_update(const std::string& symbol_name,
4027  Symbol weight,
4028  Symbol grad,
4029  Symbol weight32,
4030  mx_float lr,
4031  mx_float wd = 0,
4032  mx_float rescale_grad = 1,
4033  mx_float clip_gradient = -1) {
4034  return Operator("mp_sgd_update")
4035  .SetParam("lr", lr)
4036  .SetParam("wd", wd)
4037  .SetParam("rescale_grad", rescale_grad)
4038  .SetParam("clip_gradient", clip_gradient)
4039  .SetInput("weight", weight)
4040  .SetInput("grad", grad)
4041  .SetInput("weight32", weight32)
4042  .CreateSymbol(symbol_name);
4043 }
4044 
4061 inline Symbol mp_sgd_mom_update(const std::string& symbol_name,
4062  Symbol weight,
4063  Symbol grad,
4064  Symbol mom,
4065  Symbol weight32,
4066  mx_float lr,
4067  mx_float momentum = 0,
4068  mx_float wd = 0,
4069  mx_float rescale_grad = 1,
4070  mx_float clip_gradient = -1) {
4071  return Operator("mp_sgd_mom_update")
4072  .SetParam("lr", lr)
4073  .SetParam("momentum", momentum)
4074  .SetParam("wd", wd)
4075  .SetParam("rescale_grad", rescale_grad)
4076  .SetParam("clip_gradient", clip_gradient)
4077  .SetInput("weight", weight)
4078  .SetInput("grad", grad)
4079  .SetInput("mom", mom)
4080  .SetInput("weight32", weight32)
4081  .CreateSymbol(symbol_name);
4082 }
4083 
4131 inline Symbol adam_update(const std::string& symbol_name,
4132  Symbol weight,
4133  Symbol grad,
4134  Symbol mean,
4135  Symbol var,
4136  mx_float lr,
4137  mx_float beta1 = 0.9,
4138  mx_float beta2 = 0.999,
4139  mx_float epsilon = 1e-08,
4140  mx_float wd = 0,
4141  mx_float rescale_grad = 1,
4142  mx_float clip_gradient = -1) {
4143  return Operator("adam_update")
4144  .SetParam("lr", lr)
4145  .SetParam("beta1", beta1)
4146  .SetParam("beta2", beta2)
4147  .SetParam("epsilon", epsilon)
4148  .SetParam("wd", wd)
4149  .SetParam("rescale_grad", rescale_grad)
4150  .SetParam("clip_gradient", clip_gradient)
4151  .SetInput("weight", weight)
4152  .SetInput("grad", grad)
4153  .SetInput("mean", mean)
4154  .SetInput("var", var)
4155  .CreateSymbol(symbol_name);
4156 }
4157 
4211 inline Symbol rmsprop_update(const std::string& symbol_name,
4212  Symbol weight,
4213  Symbol grad,
4214  Symbol n,
4215  mx_float lr,
4216  mx_float gamma1 = 0.95,
4217  mx_float epsilon = 1e-08,
4218  mx_float wd = 0,
4219  mx_float rescale_grad = 1,
4220  mx_float clip_gradient = -1,
4221  mx_float clip_weights = -1) {
4222  return Operator("rmsprop_update")
4223  .SetParam("lr", lr)
4224  .SetParam("gamma1", gamma1)
4225  .SetParam("epsilon", epsilon)
4226  .SetParam("wd", wd)
4227  .SetParam("rescale_grad", rescale_grad)
4228  .SetParam("clip_gradient", clip_gradient)
4229  .SetParam("clip_weights", clip_weights)
4230  .SetInput("weight", weight)
4231  .SetInput("grad", grad)
4232  .SetInput("n", n)
4233  .CreateSymbol(symbol_name);
4234 }
4235 
4281 inline Symbol rmspropalex_update(const std::string& symbol_name,
4282  Symbol weight,
4283  Symbol grad,
4284  Symbol n,
4285  Symbol g,
4286  Symbol delta,
4287  mx_float lr,
4288  mx_float gamma1 = 0.95,
4289  mx_float gamma2 = 0.9,
4290  mx_float epsilon = 1e-08,
4291  mx_float wd = 0,
4292  mx_float rescale_grad = 1,
4293  mx_float clip_gradient = -1,
4294  mx_float clip_weights = -1) {
4295  return Operator("rmspropalex_update")
4296  .SetParam("lr", lr)
4297  .SetParam("gamma1", gamma1)
4298  .SetParam("gamma2", gamma2)
4299  .SetParam("epsilon", epsilon)
4300  .SetParam("wd", wd)
4301  .SetParam("rescale_grad", rescale_grad)
4302  .SetParam("clip_gradient", clip_gradient)
4303  .SetParam("clip_weights", clip_weights)
4304  .SetInput("weight", weight)
4305  .SetInput("grad", grad)
4306  .SetInput("n", n)
4307  .SetInput("g", g)
4308  .SetInput("delta", delta)
4309  .CreateSymbol(symbol_name);
4310 }
4311 
4351 inline Symbol ftrl_update(const std::string& symbol_name,
4352  Symbol weight,
4353  Symbol grad,
4354  Symbol z,
4355  Symbol n,
4356  mx_float lr,
4357  mx_float lamda1 = 0.01,
4358  mx_float beta = 1,
4359  mx_float wd = 0,
4360  mx_float rescale_grad = 1,
4361  mx_float clip_gradient = -1) {
4362  return Operator("ftrl_update")
4363  .SetParam("lr", lr)
4364  .SetParam("lamda1", lamda1)
4365  .SetParam("beta", beta)
4366  .SetParam("wd", wd)
4367  .SetParam("rescale_grad", rescale_grad)
4368  .SetParam("clip_gradient", clip_gradient)
4369  .SetInput("weight", weight)
4370  .SetInput("grad", grad)
4371  .SetInput("z", z)
4372  .SetInput("n", n)
4373  .CreateSymbol(symbol_name);
4374 }
4375 
/*!
 * \brief Padding mode for Pad(). Each enumerator's value indexes the
 *        mode-name table inside Pad().
 */
enum class PadMode : int {
  kConstant = 0,  // "constant"
  kEdge,          // "edge"
  kReflect        // "reflect"
};
4384 
4481 inline Symbol Pad(const std::string& symbol_name,
4482  Symbol data,
4483  PadMode mode,
4484  Shape pad_width,
4485  double constant_value = 0) {
4486  static const char *PadModeValues[] = {
4487  "constant",
4488  "edge",
4489  "reflect"
4490  };
4491  return Operator("Pad")
4492  .SetParam("mode", PadModeValues[int(mode)])
4493  .SetParam("pad_width", pad_width)
4494  .SetParam("constant_value", constant_value)
4495  .SetInput("data", data)
4496  .CreateSymbol(symbol_name);
4497 }
4498 
4508 inline Symbol IdentityAttachKLSparseReg(const std::string& symbol_name,
4509  Symbol data,
4510  mx_float sparseness_target = 0.1,
4511  mx_float penalty = 0.001,
4512  mx_float momentum = 0.9) {
4513  return Operator("IdentityAttachKLSparseReg")
4514  .SetParam("sparseness_target", sparseness_target)
4515  .SetParam("penalty", penalty)
4516  .SetParam("momentum", momentum)
4517  .SetInput("data", data)
4518  .CreateSymbol(symbol_name);
4519 }
4520 
4592 inline Symbol SliceChannel(const std::string& symbol_name,
4593  Symbol data,
4594  int num_outputs,
4595  int axis = 1,
4596  bool squeeze_axis = false) {
4597  return Operator("SliceChannel")
4598  .SetParam("num_outputs", num_outputs)
4599  .SetParam("axis", axis)
4600  .SetParam("squeeze_axis", squeeze_axis)
4601  .SetInput("data", data)
4602  .CreateSymbol(symbol_name);
4603 }
4604 
4642 inline Symbol softmax_cross_entropy(const std::string& symbol_name,
4643  Symbol data,
4644  Symbol label) {
4645  return Operator("softmax_cross_entropy")
4646  .SetInput("data", data)
4647  .SetInput("label", label)
4648  .CreateSymbol(symbol_name);
4649 }
4650 
4654  kBilinear = 0,
4655  kNearest = 1
4656 };
4657 
4662  kConcat = 0,
4663  kSum = 1
4664 };
4665 
4681 inline Symbol UpSampling(const std::string& symbol_name,
4682  const std::vector<Symbol>& data,
4683  uint32_t scale,
4684  UpSamplingSampleType sample_type,
4685  int num_args,
4686  uint32_t num_filter = 0,
4688  uint64_t workspace = 512) {
4689  static const char *UpSamplingSampleTypeValues[] = {
4690  "bilinear",
4691  "nearest"
4692  };
4693  static const char *UpSamplingMultiInputModeValues[] = {
4694  "concat",
4695  "sum"
4696  };
4697  return Operator("UpSampling")
4698  .SetParam("scale", scale)
4699  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
4700  .SetParam("num_args", num_args)
4701  .SetParam("num_filter", num_filter)
4702  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
4703  .SetParam("workspace", workspace)
4704 (data)
4705  .CreateSymbol(symbol_name);
4706 }
4707 
4771 inline Symbol BatchNorm(const std::string& symbol_name,
4772  Symbol data,
4773  Symbol gamma,
4774  Symbol beta,
4775  Symbol moving_mean,
4776  Symbol moving_var,
4777  double eps = 0.001,
4778  mx_float momentum = 0.9,
4779  bool fix_gamma = true,
4780  bool use_global_stats = false,
4781  bool output_mean_var = false,
4782  int axis = 1,
4783  bool cudnn_off = false) {
4784  return Operator("BatchNorm")
4785  .SetParam("eps", eps)
4786  .SetParam("momentum", momentum)
4787  .SetParam("fix_gamma", fix_gamma)
4788  .SetParam("use_global_stats", use_global_stats)
4789  .SetParam("output_mean_var", output_mean_var)
4790  .SetParam("axis", axis)
4791  .SetParam("cudnn_off", cudnn_off)
4792  .SetInput("data", data)
4793  .SetInput("gamma", gamma)
4794  .SetInput("beta", beta)
4795  .SetInput("moving_mean", moving_mean)
4796  .SetInput("moving_var", moving_var)
4797  .CreateSymbol(symbol_name);
4798 }
4799 
4850 inline Symbol InstanceNorm(const std::string& symbol_name,
4851  Symbol data,
4852  Symbol gamma,
4853  Symbol beta,
4854  mx_float eps = 0.001) {
4855  return Operator("InstanceNorm")
4856  .SetParam("eps", eps)
4857  .SetInput("data", data)
4858  .SetInput("gamma", gamma)
4859  .SetInput("beta", beta)
4860  .CreateSymbol(symbol_name);
4861 }
4862 
/*!
 * \brief Cell type for RNN(). Each enumerator's value indexes the mode-name
 *        table inside RNN().
 */
enum class RNNMode : int {
  kGru = 0,   // "gru"
  kLstm,      // "lstm"
  kRnn_relu,  // "rnn_relu"
  kRnn_tanh   // "rnn_tanh"
};
4871 
4887 inline Symbol RNN(const std::string& symbol_name,
4888  Symbol data,
4889  Symbol parameters,
4890  Symbol state,
4891  Symbol state_cell,
4892  uint32_t state_size,
4893  uint32_t num_layers,
4894  RNNMode mode,
4895  bool bidirectional = false,
4896  mx_float p = 0,
4897  bool state_outputs = false) {
4898  static const char *RNNModeValues[] = {
4899  "gru",
4900  "lstm",
4901  "rnn_relu",
4902  "rnn_tanh"
4903  };
4904  return Operator("RNN")
4905  .SetParam("state_size", state_size)
4906  .SetParam("num_layers", num_layers)
4907  .SetParam("mode", RNNModeValues[int(mode)])
4908  .SetParam("bidirectional", bidirectional)
4909  .SetParam("p", p)
4910  .SetParam("state_outputs", state_outputs)
4911  .SetInput("data", data)
4912  .SetInput("parameters", parameters)
4913  .SetInput("state", state)
4914  .SetInput("state_cell", state_cell)
4915  .CreateSymbol(symbol_name);
4916 }
4917 
4928  kNone = 0,
4929  kFastest = 1,
4930  kLimited_workspace = 2,
4931  kOff = 3
4932 };
4933 
4938  kNone = 0,
4939  kNCDHW = 1,
4940  kNCHW = 2,
4941  kNDHWC = 3,
4942  kNHWC = 4
4943 };
4944 
4973 inline Symbol Convolution_v1(const std::string& symbol_name,
4974  Symbol data,
4975  Symbol weight,
4976  Symbol bias,
4977  Shape kernel,
4978  uint32_t num_filter,
4979  Shape stride = Shape(),
4980  Shape dilate = Shape(),
4981  Shape pad = Shape(),
4982  uint32_t num_group = 1,
4983  uint64_t workspace = 1024,
4984  bool no_bias = false,
4986  bool cudnn_off = false,
4988  static const char *Convolution_v1CudnnTuneValues[] = {
4989  "None",
4990  "fastest",
4991  "limited_workspace",
4992  "off"
4993  };
4994  static const char *Convolution_v1LayoutValues[] = {
4995  "None",
4996  "NCDHW",
4997  "NCHW",
4998  "NDHWC",
4999  "NHWC"
5000  };
5001  return Operator("Convolution_v1")
5002  .SetParam("kernel", kernel)
5003  .SetParam("num_filter", num_filter)
5004  .SetParam("stride", stride)
5005  .SetParam("dilate", dilate)
5006  .SetParam("pad", pad)
5007  .SetParam("num_group", num_group)
5008  .SetParam("workspace", workspace)
5009  .SetParam("no_bias", no_bias)
5010  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
5011  .SetParam("cudnn_off", cudnn_off)
5012  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
5013  .SetInput("data", data)
5014  .SetInput("weight", weight)
5015  .SetInput("bias", bias)
5016  .CreateSymbol(symbol_name);
5017 }
5018 
5039 inline Symbol Crop(const std::string& symbol_name,
5040  const std::vector<Symbol>& data,
5041  int num_args,
5042  Shape offset = Shape(0,0),
5043  Shape h_w = Shape(0,0),
5044  bool center_crop = false) {
5045  return Operator("Crop")
5046  .SetParam("num_args", num_args)
5047  .SetParam("offset", offset)
5048  .SetParam("h_w", h_w)
5049  .SetParam("center_crop", center_crop)
5050 (data)
5051  .CreateSymbol(symbol_name);
5052 }
5053 
5057  kAffine = 0
5058 };
5059 
5063  kBilinear = 0
5064 };
5065 
5076 inline Symbol SpatialTransformer(const std::string& symbol_name,
5077  Symbol data,
5078  Symbol loc,
5079  SpatialTransformerTransformType transform_type,
5080  SpatialTransformerSamplerType sampler_type,
5081  Shape target_shape = Shape(0,0)) {
5082  static const char *SpatialTransformerTransformTypeValues[] = {
5083  "affine"
5084  };
5085  static const char *SpatialTransformerSamplerTypeValues[] = {
5086  "bilinear"
5087  };
5088  return Operator("SpatialTransformer")
5089  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
5090  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
5091  .SetParam("target_shape", target_shape)
5092  .SetInput("data", data)
5093  .SetInput("loc", loc)
5094  .CreateSymbol(symbol_name);
5095 }
5096 
5100  kNone = 0,
5101  kFastest = 1,
5102  kLimited_workspace = 2,
5103  kOff = 3
5104 };
5105 
5109  kNone = 0,
5110  kNCDHW = 1,
5111  kNCHW = 2,
5112  kNCW = 3,
5113  kNDHWC = 4,
5114  kNHWC = 5
5115 };
5116 
5143 inline Symbol Deconvolution(const std::string& symbol_name,
5144  Symbol data,
5145  Symbol weight,
5146  Symbol bias,
5147  Shape kernel,
5148  uint32_t num_filter,
5149  Shape stride = Shape(),
5150  Shape dilate = Shape(),
5151  Shape pad = Shape(),
5152  Shape adj = Shape(),
5153  Shape target_shape = Shape(),
5154  uint32_t num_group = 1,
5155  uint64_t workspace = 512,
5156  bool no_bias = true,
5158  bool cudnn_off = false,
5160  static const char *DeconvolutionCudnnTuneValues[] = {
5161  "None",
5162  "fastest",
5163  "limited_workspace",
5164  "off"
5165  };
5166  static const char *DeconvolutionLayoutValues[] = {
5167  "None",
5168  "NCDHW",
5169  "NCHW",
5170  "NCW",
5171  "NDHWC",
5172  "NHWC"
5173  };
5174  return Operator("Deconvolution")
5175  .SetParam("kernel", kernel)
5176  .SetParam("num_filter", num_filter)
5177  .SetParam("stride", stride)
5178  .SetParam("dilate", dilate)
5179  .SetParam("pad", pad)
5180  .SetParam("adj", adj)
5181  .SetParam("target_shape", target_shape)
5182  .SetParam("num_group", num_group)
5183  .SetParam("workspace", workspace)
5184  .SetParam("no_bias", no_bias)
5185  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
5186  .SetParam("cudnn_off", cudnn_off)
5187  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
5188  .SetInput("data", data)
5189  .SetInput("weight", weight)
5190  .SetInput("bias", bias)
5191  .CreateSymbol(symbol_name);
5192 }
5193 
5197  kBatch = 0,
5198  kNull = 1,
5199  kValid = 2
5200 };
5201 
5297 inline Symbol SoftmaxOutput(const std::string& symbol_name,
5298  Symbol data,
5299  Symbol label,
5300  mx_float grad_scale = 1,
5301  mx_float ignore_label = -1,
5302  bool multi_output = false,
5303  bool use_ignore = false,
5304  bool preserve_shape = false,
5306  bool out_grad = false,
5307  mx_float smooth_alpha = 0) {
5308  static const char *SoftmaxOutputNormalizationValues[] = {
5309  "batch",
5310  "null",
5311  "valid"
5312  };
5313  return Operator("SoftmaxOutput")
5314  .SetParam("grad_scale", grad_scale)
5315  .SetParam("ignore_label", ignore_label)
5316  .SetParam("multi_output", multi_output)
5317  .SetParam("use_ignore", use_ignore)
5318  .SetParam("preserve_shape", preserve_shape)
5319  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
5320  .SetParam("out_grad", out_grad)
5321  .SetParam("smooth_alpha", smooth_alpha)
5322  .SetInput("data", data)
5323  .SetInput("label", label)
5324  .CreateSymbol(symbol_name);
5325 }
5326 
5330  kBatch = 0,
5331  kNull = 1,
5332  kValid = 2
5333 };
5334 
5362 inline Symbol Softmax(const std::string& symbol_name,
5363  Symbol data,
5364  mx_float grad_scale = 1,
5365  mx_float ignore_label = -1,
5366  bool multi_output = false,
5367  bool use_ignore = false,
5368  bool preserve_shape = false,
5370  bool out_grad = false,
5371  mx_float smooth_alpha = 0) {
5372  static const char *SoftmaxNormalizationValues[] = {
5373  "batch",
5374  "null",
5375  "valid"
5376  };
5377  return Operator("Softmax")
5378  .SetParam("grad_scale", grad_scale)
5379  .SetParam("ignore_label", ignore_label)
5380  .SetParam("multi_output", multi_output)
5381  .SetParam("use_ignore", use_ignore)
5382  .SetParam("preserve_shape", preserve_shape)
5383  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
5384  .SetParam("out_grad", out_grad)
5385  .SetParam("smooth_alpha", smooth_alpha)
5386  .SetInput("data", data)
5387  .CreateSymbol(symbol_name);
5388 }
5389 
5465 inline Symbol SequenceReverse(const std::string& symbol_name,
5466  Symbol data,
5467  Symbol sequence_length,
5468  bool use_sequence_length = false) {
5469  return Operator("SequenceReverse")
5470  .SetParam("use_sequence_length", use_sequence_length)
5471  .SetInput("data", data)
5472  .SetInput("sequence_length", sequence_length)
5473  .CreateSymbol(symbol_name);
5474 }
5475 
5530 inline Symbol SequenceLast(const std::string& symbol_name,
5531  Symbol data,
5532  Symbol sequence_length,
5533  bool use_sequence_length = false) {
5534  return Operator("SequenceLast")
5535  .SetParam("use_sequence_length", use_sequence_length)
5536  .SetInput("data", data)
5537  .SetInput("sequence_length", sequence_length)
5538  .CreateSymbol(symbol_name);
5539 }
5540 
5589 inline Symbol Correlation(const std::string& symbol_name,
5590  Symbol data1,
5591  Symbol data2,
5592  uint32_t kernel_size = 1,
5593  uint32_t max_displacement = 1,
5594  uint32_t stride1 = 1,
5595  uint32_t stride2 = 1,
5596  uint32_t pad_size = 0,
5597  bool is_multiply = true) {
5598  return Operator("Correlation")
5599  .SetParam("kernel_size", kernel_size)
5600  .SetParam("max_displacement", max_displacement)
5601  .SetParam("stride1", stride1)
5602  .SetParam("stride2", stride2)
5603  .SetParam("pad_size", pad_size)
5604  .SetParam("is_multiply", is_multiply)
5605  .SetInput("data1", data1)
5606  .SetInput("data2", data2)
5607  .CreateSymbol(symbol_name);
5608 }
5609 
5625 inline Symbol SVMOutput(const std::string& symbol_name,
5626  Symbol data,
5627  Symbol label,
5628  mx_float margin = 1,
5629  mx_float regularization_coefficient = 1,
5630  bool use_linear = false) {
5631  return Operator("SVMOutput")
5632  .SetParam("margin", margin)
5633  .SetParam("regularization_coefficient", regularization_coefficient)
5634  .SetParam("use_linear", use_linear)
5635  .SetInput("data", data)
5636  .SetInput("label", label)
5637  .CreateSymbol(symbol_name);
5638 }
5639 
5643  kChannel = 0,
5644  kInstance = 1,
5645  kSpatial = 2
5646 };
5647 
5710 inline Symbol L2Normalization(const std::string& symbol_name,
5711  Symbol data,
5712  mx_float eps = 1e-10,
5714  static const char *L2NormalizationModeValues[] = {
5715  "channel",
5716  "instance",
5717  "spatial"
5718  };
5719  return Operator("L2Normalization")
5720  .SetParam("eps", eps)
5721  .SetParam("mode", L2NormalizationModeValues[int(mode)])
5722  .SetInput("data", data)
5723  .CreateSymbol(symbol_name);
5724 }
5725 
5753 inline Symbol LRN(const std::string& symbol_name,
5754  Symbol data,
5755  uint32_t nsize,
5756  mx_float alpha = 0.0001,
5757  mx_float beta = 0.75,
5758  mx_float knorm = 2) {
5759  return Operator("LRN")
5760  .SetParam("nsize", nsize)
5761  .SetParam("alpha", alpha)
5762  .SetParam("beta", beta)
5763  .SetParam("knorm", knorm)
5764  .SetInput("data", data)
5765  .CreateSymbol(symbol_name);
5766 }
5767 
5801 inline Symbol FullyConnected(const std::string& symbol_name,
5802  Symbol data,
5803  Symbol weight,
5804  Symbol bias,
5805  int num_hidden,
5806  bool no_bias = false,
5807  bool flatten = true) {
5808  return Operator("FullyConnected")
5809  .SetParam("num_hidden", num_hidden)
5810  .SetParam("no_bias", no_bias)
5811  .SetParam("flatten", flatten)
5812  .SetInput("data", data)
5813  .SetInput("weight", weight)
5814  .SetInput("bias", bias)
5815  .CreateSymbol(symbol_name);
5816 }
5817 
5895 inline Symbol SequenceMask(const std::string& symbol_name,
5896  Symbol data,
5897  Symbol sequence_length,
5898  bool use_sequence_length = false,
5899  mx_float value = 0) {
5900  return Operator("SequenceMask")
5901  .SetParam("use_sequence_length", use_sequence_length)
5902  .SetParam("value", value)
5903  .SetInput("data", data)
5904  .SetInput("sequence_length", sequence_length)
5905  .CreateSymbol(symbol_name);
5906 }
5907 
5912  kAffine = 0,
5913  kWarp = 1
5914 };
5915 
5926 inline Symbol GridGenerator(const std::string& symbol_name,
5927  Symbol data,
5928  GridGeneratorTransformType transform_type,
5929  Shape target_shape = Shape(0,0)) {
5930  static const char *GridGeneratorTransformTypeValues[] = {
5931  "affine",
5932  "warp"
5933  };
5934  return Operator("GridGenerator")
5935  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
5936  .SetParam("target_shape", target_shape)
5937  .SetInput("data", data)
5938  .CreateSymbol(symbol_name);
5939 }
5940 
5944  kAvg = 0,
5945  kMax = 1,
5946  kSum = 2
5947 };
5948 
5952  kFull = 0,
5953  kValid = 1
5954 };
5955 
6007 inline Symbol Pooling_v1(const std::string& symbol_name,
6008  Symbol data,
6009  Shape kernel,
6010  Pooling_v1PoolType pool_type,
6011  bool global_pool = false,
6013  Shape stride = Shape(),
6014  Shape pad = Shape()) {
6015  static const char *Pooling_v1PoolTypeValues[] = {
6016  "avg",
6017  "max",
6018  "sum"
6019  };
6020  static const char *Pooling_v1PoolingConventionValues[] = {
6021  "full",
6022  "valid"
6023  };
6024  return Operator("Pooling_v1")
6025  .SetParam("kernel", kernel)
6026  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
6027  .SetParam("global_pool", global_pool)
6028  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
6029  .SetParam("stride", stride)
6030  .SetParam("pad", pad)
6031  .SetInput("data", data)
6032  .CreateSymbol(symbol_name);
6033 }
6034 
6038  kNone = 0,
6039  kFastest = 1,
6040  kLimited_workspace = 2,
6041  kOff = 3
6042 };
6043 
/*!
 * \brief Data layout for Convolution(). Each enumerator's value indexes the
 *        layout-name table inside Convolution().
 */
enum class ConvolutionLayout : int {
  kNone = 0,  // "None"
  kNCDHW,     // "NCDHW"
  kNCHW,      // "NCHW"
  kNCW,       // "NCW"
  kNDHWC,     // "NDHWC"
  kNHWC       // "NHWC"
};
6055 
6150 inline Symbol Convolution(const std::string& symbol_name,
6151  Symbol data,
6152  Symbol weight,
6153  Symbol bias,
6154  Shape kernel,
6155  uint32_t num_filter,
6156  Shape stride = Shape(),
6157  Shape dilate = Shape(),
6158  Shape pad = Shape(),
6159  uint32_t num_group = 1,
6160  uint64_t workspace = 1024,
6161  bool no_bias = false,
6163  bool cudnn_off = false,
6165  static const char *ConvolutionCudnnTuneValues[] = {
6166  "None",
6167  "fastest",
6168  "limited_workspace",
6169  "off"
6170  };
6171  static const char *ConvolutionLayoutValues[] = {
6172  "None",
6173  "NCDHW",
6174  "NCHW",
6175  "NCW",
6176  "NDHWC",
6177  "NHWC"
6178  };
6179  return Operator("Convolution")
6180  .SetParam("kernel", kernel)
6181  .SetParam("num_filter", num_filter)
6182  .SetParam("stride", stride)
6183  .SetParam("dilate", dilate)
6184  .SetParam("pad", pad)
6185  .SetParam("num_group", num_group)
6186  .SetParam("workspace", workspace)
6187  .SetParam("no_bias", no_bias)
6188  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
6189  .SetParam("cudnn_off", cudnn_off)
6190  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
6191  .SetInput("data", data)
6192  .SetInput("weight", weight)
6193  .SetInput("bias", bias)
6194  .CreateSymbol(symbol_name);
6195 }
6196 
6277 inline Symbol BilinearSampler(const std::string& symbol_name,
6278  Symbol data,
6279  Symbol grid) {
6280  return Operator("BilinearSampler")
6281  .SetInput("data", data)
6282  .SetInput("grid", grid)
6283  .CreateSymbol(symbol_name);
6284 }
6285 
/*!
 * \brief Pooling reduction for Pooling(). Each enumerator's value indexes the
 *        pool-type-name table inside Pooling().
 */
enum class PoolingPoolType : int {
  kAvg = 0,  // "avg"
  kMax,      // "max"
  kSum       // "sum"
};
6293 
6297  kFull = 0,
6298  kValid = 1
6299 };
6300 
6354 inline Symbol Pooling(const std::string& symbol_name,
6355  Symbol data,
6356  Shape kernel,
6357  PoolingPoolType pool_type,
6358  bool global_pool = false,
6359  bool cudnn_off = false,
6361  Shape stride = Shape(),
6362  Shape pad = Shape()) {
6363  static const char *PoolingPoolTypeValues[] = {
6364  "avg",
6365  "max",
6366  "sum"
6367  };
6368  static const char *PoolingPoolingConventionValues[] = {
6369  "full",
6370  "valid"
6371  };
6372  return Operator("Pooling")
6373  .SetParam("kernel", kernel)
6374  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
6375  .SetParam("global_pool", global_pool)
6376  .SetParam("cudnn_off", cudnn_off)
6377  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
6378  .SetParam("stride", stride)
6379  .SetParam("pad", pad)
6380  .SetInput("data", data)
6381  .CreateSymbol(symbol_name);
6382 }
6383 
/*!
 * \brief Mode for Dropout(). Each enumerator's value indexes the mode-name
 *        table inside Dropout().
 */
enum class DropoutMode : int {
  kAlways = 0,  // "always"
  kTraining     // "training"
};
6390 
6430 inline Symbol Dropout(const std::string& symbol_name,
6431  Symbol data,
6432  mx_float p = 0.5,
6434  static const char *DropoutModeValues[] = {
6435  "always",
6436  "training"
6437  };
6438  return Operator("Dropout")
6439  .SetParam("p", p)
6440  .SetParam("mode", DropoutModeValues[int(mode)])
6441  .SetInput("data", data)
6442  .CreateSymbol(symbol_name);
6443 }
6444 
/*!
 * \brief Activation kind for Activation(). Each enumerator's value indexes the
 *        act-type-name table inside Activation().
 */
enum class ActivationActType : int {
  kRelu = 0,  // "relu"
  kSigmoid,   // "sigmoid"
  kSoftrelu,  // "softrelu"
  kTanh       // "tanh"
};
6453 
6472 inline Symbol Activation(const std::string& symbol_name,
6473  Symbol data,
6474  ActivationActType act_type) {
6475  static const char *ActivationActTypeValues[] = {
6476  "relu",
6477  "sigmoid",
6478  "softrelu",
6479  "tanh"
6480  };
6481  return Operator("Activation")
6482  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
6483  .SetInput("data", data)
6484  .CreateSymbol(symbol_name);
6485 }
6486 
6543 inline Symbol ROIPooling(const std::string& symbol_name,
6544  Symbol data,
6545  Symbol rois,
6546  Shape pooled_size,
6547  mx_float spatial_scale) {
6548  return Operator("ROIPooling")
6549  .SetParam("pooled_size", pooled_size)
6550  .SetParam("spatial_scale", spatial_scale)
6551  .SetInput("data", data)
6552  .SetInput("rois", rois)
6553  .CreateSymbol(symbol_name);
6554 }
6555 
6580 inline Symbol LinearRegressionOutput(const std::string& symbol_name,
6581  Symbol data,
6582  Symbol label,
6583  mx_float grad_scale = 1) {
6584  return Operator("LinearRegressionOutput")
6585  .SetParam("grad_scale", grad_scale)
6586  .SetInput("data", data)
6587  .SetInput("label", label)
6588  .CreateSymbol(symbol_name);
6589 }
6590 
6616 inline Symbol MAERegressionOutput(const std::string& symbol_name,
6617  Symbol data,
6618  Symbol label,
6619  mx_float grad_scale = 1) {
6620  return Operator("MAERegressionOutput")
6621  .SetParam("grad_scale", grad_scale)
6622  .SetInput("data", data)
6623  .SetInput("label", label)
6624  .CreateSymbol(symbol_name);
6625 }
6626 
6652 inline Symbol LogisticRegressionOutput(const std::string& symbol_name,
6653  Symbol data,
6654  Symbol label,
6655  mx_float grad_scale = 1) {
6656  return Operator("LogisticRegressionOutput")
6657  .SetParam("grad_scale", grad_scale)
6658  .SetInput("data", data)
6659  .SetInput("label", label)
6660  .CreateSymbol(symbol_name);
6661 }
6662 
6667  kChannel = 0,
6668  kInstance = 1
6669 };
6670 
6704 inline Symbol SoftmaxActivation(const std::string& symbol_name,
6705  Symbol data,
6707  static const char *SoftmaxActivationModeValues[] = {
6708  "channel",
6709  "instance"
6710  };
6711  return Operator("SoftmaxActivation")
6712  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
6713  .SetInput("data", data)
6714  .CreateSymbol(symbol_name);
6715 }
6716 
6722  kBatch = 0,
6723  kNull = 1,
6724  kValid = 2
6725 };
6726 
6761 inline Symbol MakeLoss(const std::string& symbol_name,
6762  Symbol data,
6763  mx_float grad_scale = 1,
6764  mx_float valid_thresh = 0,
6766  static const char *MakeLossNormalizationValues[] = {
6767  "batch",
6768  "null",
6769  "valid"
6770  };
6771  return Operator("MakeLoss")
6772  .SetParam("grad_scale", grad_scale)
6773  .SetParam("valid_thresh", valid_thresh)
6774  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
6775  .SetInput("data", data)
6776  .CreateSymbol(symbol_name);
6777 }
6778 
6787 inline Symbol choose_element_0index(const std::string& symbol_name,
6788  Symbol lhs,
6789  Symbol rhs) {
6790  return Operator("choose_element_0index")
6791  .SetInput("lhs", lhs)
6792  .SetInput("rhs", rhs)
6793  .CreateSymbol(symbol_name);
6794 }
6795 
6805 inline Symbol fill_element_0index(const std::string& symbol_name,
6806  Symbol lhs,
6807  Symbol mhs,
6808  Symbol rhs) {
6809  return Operator("fill_element_0index")
6810  .SetInput("lhs", lhs)
6811  .SetInput("mhs", mhs)
6812  .SetInput("rhs", rhs)
6813  .CreateSymbol(symbol_name);
6814 }
6815 
6838  Symbol rhs) {
6839  return Operator("broadcast_power")
6840  .SetInput("lhs", lhs)
6841  .SetInput("rhs", rhs)
6842  .CreateSymbol();
6843 }
6844 
6869  Symbol rhs) {
6870  return Operator("broadcast_maximum")
6871  .SetInput("lhs", lhs)
6872  .SetInput("rhs", rhs)
6873  .CreateSymbol();
6874 }
6875 
6900  Symbol rhs) {
6901  return Operator("broadcast_minimum")
6902  .SetInput("lhs", lhs)
6903  .SetInput("rhs", rhs)
6904  .CreateSymbol();
6905 }
6906 
6937  Symbol rhs) {
6938  return Operator("broadcast_hypot")
6939  .SetInput("lhs", lhs)
6940  .SetInput("rhs", rhs)
6941  .CreateSymbol();
6942 }
6943 
7017 inline Symbol Reshape(Symbol data,
7018  Shape shape = Shape(),
7019  bool reverse = false,
7020  Shape target_shape = Shape(),
7021  bool keep_highest = false) {
7022  return Operator("Reshape")
7023  .SetParam("shape", shape)
7024  .SetParam("reverse", reverse)
7025  .SetParam("target_shape", target_shape)
7026  .SetParam("keep_highest", keep_highest)
7027  .SetInput("data", data)
7028  .CreateSymbol();
7029 }
7030 
7060 inline Symbol Flatten(Symbol data) {
7061  return Operator("Flatten")
7062  .SetInput("data", data)
7063  .CreateSymbol();
7064 }
7065 
7102  Shape axes = Shape()) {
7103  return Operator("transpose")
7104  .SetParam("axes", axes)
7105  .SetInput("data", data)
7106  .CreateSymbol();
7107 }
7108 
7124  int axis) {
7125  return Operator("expand_dims")
7126  .SetParam("axis", axis)
7127  .SetInput("data", data)
7128  .CreateSymbol();
7129 }
7130 
7185 inline Symbol slice(Symbol data,
7186  Shape begin,
7187  Shape end,
7188  Shape step = Shape()) {
7189  return Operator("slice")
7190  .SetParam("begin", begin)
7191  .SetParam("end", end)
7192  .SetParam("step", step)
7193  .SetInput("data", data)
7194  .CreateSymbol();
7195 }
7196 
7229  int axis,
7230  int begin,
7231  dmlc::optional<int> end) {
7232  return Operator("slice_axis")
7233  .SetParam("axis", axis)
7234  .SetParam("begin", begin)
7235  .SetParam("end", end)
7236  .SetInput("data", data)
7237  .CreateSymbol();
7238 }
7239 
7273 inline Symbol clip(Symbol data,
7274  mx_float a_min,
7275  mx_float a_max) {
7276  return Operator("clip")
7277  .SetParam("a_min", a_min)
7278  .SetParam("a_max", a_max)
7279  .SetInput("data", data)
7280  .CreateSymbol();
7281 }
7282 
7316 inline Symbol repeat(Symbol data,
7317  int repeats,
7318  dmlc::optional<int> axis = dmlc::optional<int>()) {
7319  return Operator("repeat")
7320  .SetParam("repeats", repeats)
7321  .SetParam("axis", axis)
7322  .SetInput("data", data)
7323  .CreateSymbol();
7324 }
7325 
7370 inline Symbol tile(Symbol data,
7371  Shape reps) {
7372  return Operator("tile")
7373  .SetParam("reps", reps)
7374  .SetInput("data", data)
7375  .CreateSymbol();
7376 }
7377 
7400 inline Symbol reverse(Symbol data,
7401  Shape axis) {
7402  return Operator("reverse")
7403  .SetParam("axis", axis)
7404  .SetInput("data", data)
7405  .CreateSymbol();
7406 }
7407 
7430 inline Symbol stack(const std::vector<Symbol>& data,
7431  int num_args,
7432  int axis = 0) {
7433  return Operator("stack")
7434  .SetParam("num_args", num_args)
7435  .SetParam("axis", axis)
7436 (data)
7437  .CreateSymbol();
7438 }
7439 
7462 inline Symbol zeros_like(Symbol data) {
7463  return Operator("zeros_like")
7464  .SetInput("data", data)
7465  .CreateSymbol();
7466 }
7467 
7484 inline Symbol ones_like(Symbol data) {
7485  return Operator("ones_like")
7486  .SetInput("data", data)
7487  .CreateSymbol();
7488 }
7489 
7517  Symbol rhs) {
7518  return Operator("broadcast_add")
7519  .SetInput("lhs", lhs)
7520  .SetInput("rhs", rhs)
7521  .CreateSymbol();
7522 }
7523 
7551  Symbol rhs) {
7552  return Operator("broadcast_sub")
7553  .SetInput("lhs", lhs)
7554  .SetInput("rhs", rhs)
7555  .CreateSymbol();
7556 }
7557 
7580  Symbol rhs) {
7581  return Operator("broadcast_mul")
7582  .SetInput("lhs", lhs)
7583  .SetInput("rhs", rhs)
7584  .CreateSymbol();
7585 }
7586 
7609  Symbol rhs) {
7610  return Operator("broadcast_div")
7611  .SetInput("lhs", lhs)
7612  .SetInput("rhs", rhs)
7613  .CreateSymbol();
7614 }
7615 
7638  Symbol rhs) {
7639  return Operator("broadcast_mod")
7640  .SetInput("lhs", lhs)
7641  .SetInput("rhs", rhs)
7642  .CreateSymbol();
7643 }
7644 
7664 inline Symbol add_n(const std::vector<Symbol>& args) {
7665  return Operator("add_n")
7666 (args)
7667  .CreateSymbol();
7668 }
7669 
7700 inline Symbol argmax(Symbol data,
7701  dmlc::optional<int> axis = dmlc::optional<int>(),
7702  bool keepdims = false) {
7703  return Operator("argmax")
7704  .SetParam("axis", axis)
7705  .SetParam("keepdims", keepdims)
7706  .SetInput("data", data)
7707  .CreateSymbol();
7708 }
7709 
7740 inline Symbol argmin(Symbol data,
7741  dmlc::optional<int> axis = dmlc::optional<int>(),
7742  bool keepdims = false) {
7743  return Operator("argmin")
7744  .SetParam("axis", axis)
7745  .SetParam("keepdims", keepdims)
7746  .SetInput("data", data)
7747  .CreateSymbol();
7748 }
7749 
7772  return Operator("argmax_channel")
7773  .SetInput("data", data)
7774  .CreateSymbol();
7775 }
7776 
7821 inline Symbol pick(Symbol data,
7822  Symbol index,
7823  dmlc::optional<int> axis = dmlc::optional<int>(),
7824  bool keepdims = false) {
7825  return Operator("pick")
7826  .SetParam("axis", axis)
7827  .SetParam("keepdims", keepdims)
7828  .SetInput("data", data)
7829  .SetInput("index", index)
7830  .CreateSymbol();
7831 }
7832 
7871 inline Symbol dot(Symbol lhs,
7872  Symbol rhs,
7873  bool transpose_a = false,
7874  bool transpose_b = false) {
7875  return Operator("dot")
7876  .SetParam("transpose_a", transpose_a)
7877  .SetParam("transpose_b", transpose_b)
7878  .SetInput("lhs", lhs)
7879  .SetInput("rhs", rhs)
7880  .CreateSymbol();
7881 }
7882 
7905  Symbol rhs,
7906  bool transpose_a = false,
7907  bool transpose_b = false) {
7908  return Operator("batch_dot")
7909  .SetParam("transpose_a", transpose_a)
7910  .SetParam("transpose_b", transpose_b)
7911  .SetInput("lhs", lhs)
7912  .SetInput("rhs", rhs)
7913  .CreateSymbol();
7914 }
7915 
7933 inline Symbol relu(Symbol data) {
7934  return Operator("relu")
7935  .SetInput("data", data)
7936  .CreateSymbol();
7937 }
7938 
7953 inline Symbol sigmoid(Symbol data) {
7954  return Operator("sigmoid")
7955  .SetInput("data", data)
7956  .CreateSymbol();
7957 }
7958 
7991 inline Symbol BlockGrad(Symbol data) {
7992  return Operator("BlockGrad")
7993  .SetInput("data", data)
7994  .CreateSymbol();
7995 }
7996 
8025 inline Symbol make_loss(Symbol data) {
8026  return Operator("make_loss")
8027  .SetInput("data", data)
8028  .CreateSymbol();
8029 }
8030 
8038  Symbol rhs) {
8039  return Operator("reshape_like")
8040  .SetInput("lhs", lhs)
8041  .SetInput("rhs", rhs)
8042  .CreateSymbol();
8043 }
8044 
8063 inline Symbol Cast(Symbol data,
8064  CastDtype dtype) {
8065  static const char *CastDtypeValues[] = {
8066  "float16",
8067  "float32",
8068  "float64",
8069  "int32",
8070  "uint8"
8071  };
8072  return Operator("Cast")
8073  .SetParam("dtype", CastDtypeValues[int(dtype)])
8074  .SetInput("data", data)
8075  .CreateSymbol();
8076 }
8077 
8091 inline Symbol negative(Symbol data) {
8092  return Operator("negative")
8093  .SetInput("data", data)
8094  .CreateSymbol();
8095 }
8096 
8112 inline Symbol reciprocal(Symbol data) {
8113  return Operator("reciprocal")
8114  .SetInput("data", data)
8115  .CreateSymbol();
8116 }
8117 
8136 inline Symbol abs(Symbol data) {
8137  return Operator("abs")
8138  .SetInput("data", data)
8139  .CreateSymbol();
8140 }
8141 
8160 inline Symbol sign(Symbol data) {
8161  return Operator("sign")
8162  .SetInput("data", data)
8163  .CreateSymbol();
8164 }
8165 
8184 inline Symbol round(Symbol data) {
8185  return Operator("round")
8186  .SetInput("data", data)
8187  .CreateSymbol();
8188 }
8189 
8212 inline Symbol rint(Symbol data) {
8213  return Operator("rint")
8214  .SetInput("data", data)
8215  .CreateSymbol();
8216 }
8217 
8238 inline Symbol ceil(Symbol data) {
8239  return Operator("ceil")
8240  .SetInput("data", data)
8241  .CreateSymbol();
8242 }
8243 
8264 inline Symbol floor(Symbol data) {
8265  return Operator("floor")
8266  .SetInput("data", data)
8267  .CreateSymbol();
8268 }
8269 
8291 inline Symbol trunc(Symbol data) {
8292  return Operator("trunc")
8293  .SetInput("data", data)
8294  .CreateSymbol();
8295 }
8296 
8316 inline Symbol fix(Symbol data) {
8317  return Operator("fix")
8318  .SetInput("data", data)
8319  .CreateSymbol();
8320 }
8321 
8344 inline Symbol square(Symbol data) {
8345  return Operator("square")
8346  .SetInput("data", data)
8347  .CreateSymbol();
8348 }
8349 
8371 inline Symbol sqrt(Symbol data) {
8372  return Operator("sqrt")
8373  .SetInput("data", data)
8374  .CreateSymbol();
8375 }
8376 
8395 inline Symbol rsqrt(Symbol data) {
8396  return Operator("rsqrt")
8397  .SetInput("data", data)
8398  .CreateSymbol();
8399 }
8400 
8417 inline Symbol cbrt(Symbol data) {
8418  return Operator("cbrt")
8419  .SetInput("data", data)
8420  .CreateSymbol();
8421 }
8422 
8439 inline Symbol rcbrt(Symbol data) {
8440  return Operator("rcbrt")
8441  .SetInput("data", data)
8442  .CreateSymbol();
8443 }
8444 
8463 inline Symbol exp(Symbol data) {
8464  return Operator("exp")
8465  .SetInput("data", data)
8466  .CreateSymbol();
8467 }
8468 
8482 inline Symbol log(Symbol data) {
8483  return Operator("log")
8484  .SetInput("data", data)
8485  .CreateSymbol();
8486 }
8487 
8501 inline Symbol log10(Symbol data) {
8502  return Operator("log10")
8503  .SetInput("data", data)
8504  .CreateSymbol();
8505 }
8506 
8520 inline Symbol log2(Symbol data) {
8521  return Operator("log2")
8522  .SetInput("data", data)
8523  .CreateSymbol();
8524 }
8525 
8543 inline Symbol log1p(Symbol data) {
8544  return Operator("log1p")
8545  .SetInput("data", data)
8546  .CreateSymbol();
8547 }
8548 
8565 inline Symbol expm1(Symbol data) {
8566  return Operator("expm1")
8567  .SetInput("data", data)
8568  .CreateSymbol();
8569 }
8570 
8581 inline Symbol gamma(Symbol data) {
8582  return Operator("gamma")
8583  .SetInput("data", data)
8584  .CreateSymbol();
8585 }
8586 
8597 inline Symbol gammaln(Symbol data) {
8598  return Operator("gammaln")
8599  .SetInput("data", data)
8600  .CreateSymbol();
8601 }
8602 
8660 inline Symbol sum(Symbol data,
8661  Shape axis = Shape(),
8662  bool keepdims = false,
8663  bool exclude = false) {
8664  return Operator("sum")
8665  .SetParam("axis", axis)
8666  .SetParam("keepdims", keepdims)
8667  .SetParam("exclude", exclude)
8668  .SetInput("data", data)
8669  .CreateSymbol();
8670 }
8671 
8695 inline Symbol mean(Symbol data,
8696  Shape axis = Shape(),
8697  bool keepdims = false,
8698  bool exclude = false) {
8699  return Operator("mean")
8700  .SetParam("axis", axis)
8701  .SetParam("keepdims", keepdims)
8702  .SetParam("exclude", exclude)
8703  .SetInput("data", data)
8704  .CreateSymbol();
8705 }
8706 
8730 inline Symbol prod(Symbol data,
8731  Shape axis = Shape(),
8732  bool keepdims = false,
8733  bool exclude = false) {
8734  return Operator("prod")
8735  .SetParam("axis", axis)
8736  .SetParam("keepdims", keepdims)
8737  .SetParam("exclude", exclude)
8738  .SetInput("data", data)
8739  .CreateSymbol();
8740 }
8741 
8767 inline Symbol nansum(Symbol data,
8768  Shape axis = Shape(),
8769  bool keepdims = false,
8770  bool exclude = false) {
8771  return Operator("nansum")
8772  .SetParam("axis", axis)
8773  .SetParam("keepdims", keepdims)
8774  .SetParam("exclude", exclude)
8775  .SetInput("data", data)
8776  .CreateSymbol();
8777 }
8778 
8804 inline Symbol nanprod(Symbol data,
8805  Shape axis = Shape(),
8806  bool keepdims = false,
8807  bool exclude = false) {
8808  return Operator("nanprod")
8809  .SetParam("axis", axis)
8810  .SetParam("keepdims", keepdims)
8811  .SetParam("exclude", exclude)
8812  .SetInput("data", data)
8813  .CreateSymbol();
8814 }
8815 
8839 inline Symbol max(Symbol data,
8840  Shape axis = Shape(),
8841  bool keepdims = false,
8842  bool exclude = false) {
8843  return Operator("max")
8844  .SetParam("axis", axis)
8845  .SetParam("keepdims", keepdims)
8846  .SetParam("exclude", exclude)
8847  .SetInput("data", data)
8848  .CreateSymbol();
8849 }
8850 
8874 inline Symbol min(Symbol data,
8875  Shape axis = Shape(),
8876  bool keepdims = false,
8877  bool exclude = false) {
8878  return Operator("min")
8879  .SetParam("axis", axis)
8880  .SetParam("keepdims", keepdims)
8881  .SetParam("exclude", exclude)
8882  .SetInput("data", data)
8883  .CreateSymbol();
8884 }
8885 
8915  Shape axis = Shape(),
8916  Shape size = Shape()) {
8917  return Operator("broadcast_axis")
8918  .SetParam("axis", axis)
8919  .SetParam("size", size)
8920  .SetInput("data", data)
8921  .CreateSymbol();
8922 }
8923 
8952  Shape shape = Shape()) {
8953  return Operator("broadcast_to")
8954  .SetParam("shape", shape)
8955  .SetInput("data", data)
8956  .CreateSymbol();
8957 }
8958 
8975 inline Symbol norm(Symbol data) {
8976  return Operator("norm")
8977  .SetInput("data", data)
8978  .CreateSymbol();
8979 }
8980 
9021 inline Symbol topk(Symbol data,
9022  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9023  int k = 1,
9024  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
9025  bool is_ascend = false) {
9026  static const char *TopkRetTypValues[] = {
9027  "both",
9028  "indices",
9029  "mask",
9030  "value"
9031  };
9032  return Operator("topk")
9033  .SetParam("axis", axis)
9034  .SetParam("k", k)
9035  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
9036  .SetParam("is_ascend", is_ascend)
9037  .SetInput("data", data)
9038  .CreateSymbol();
9039 }
9040 
9072 inline Symbol sort(Symbol data,
9073  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9074  bool is_ascend = true) {
9075  return Operator("sort")
9076  .SetParam("axis", axis)
9077  .SetParam("is_ascend", is_ascend)
9078  .SetInput("data", data)
9079  .CreateSymbol();
9080 }
9081 
9111 inline Symbol argsort(Symbol data,
9112  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9113  bool is_ascend = true) {
9114  return Operator("argsort")
9115  .SetParam("axis", axis)
9116  .SetParam("is_ascend", is_ascend)
9117  .SetInput("data", data)
9118  .CreateSymbol();
9119 }
9120 
9136  Symbol rhs) {
9137  return Operator("elemwise_add")
9138  .SetInput("lhs", lhs)
9139  .SetInput("rhs", rhs)
9140  .CreateSymbol();
9141 }
9142 
9158  Symbol rhs) {
9159  return Operator("elemwise_sub")
9160  .SetInput("lhs", lhs)
9161  .SetInput("rhs", rhs)
9162  .CreateSymbol();
9163 }
9164 
9183  Symbol rhs) {
9184  return Operator("elemwise_mul")
9185  .SetInput("lhs", lhs)
9186  .SetInput("rhs", rhs)
9187  .CreateSymbol();
9188 }
9189 
9201  Symbol rhs) {
9202  return Operator("elemwise_div")
9203  .SetInput("lhs", lhs)
9204  .SetInput("rhs", rhs)
9205  .CreateSymbol();
9206 }
9207 
9259  Symbol weight,
9260  int input_dim,
9261  int output_dim,
9263  static const char *EmbeddingDtypeValues[] = {
9264  "float16",
9265  "float32",
9266  "float64",
9267  "int32",
9268  "uint8"
9269  };
9270  return Operator("Embedding")
9271  .SetParam("input_dim", input_dim)
9272  .SetParam("output_dim", output_dim)
9273  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
9274  .SetInput("data", data)
9275  .SetInput("weight", weight)
9276  .CreateSymbol();
9277 }
9278 
9317 inline Symbol take(Symbol a,
9318  Symbol indices,
9319  int axis = 0,
9320  TakeMode mode = TakeMode::kClip) {
9321  static const char *TakeModeValues[] = {
9322  "clip",
9323  "raise",
9324  "wrap"
9325  };
9326  return Operator("take")
9327  .SetParam("axis", axis)
9328  .SetParam("mode", TakeModeValues[int(mode)])
9329  .SetInput("a", a)
9330  .SetInput("indices", indices)
9331  .CreateSymbol();
9332 }
9333 
9362  Symbol indices) {
9363  return Operator("batch_take")
9364  .SetInput("a", a)
9365  .SetInput("indices", indices)
9366  .CreateSymbol();
9367 }
9368 
9412 inline Symbol one_hot(Symbol indices,
9413  int depth,
9414  double on_value = 1,
9415  double off_value = 0,
9417  static const char *One_hotDtypeValues[] = {
9418  "float16",
9419  "float32",
9420  "float64",
9421  "int32",
9422  "uint8"
9423  };
9424  return Operator("one_hot")
9425  .SetParam("depth", depth)
9426  .SetParam("on_value", on_value)
9427  .SetParam("off_value", off_value)
9428  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
9429  .SetInput("indices", indices)
9430  .CreateSymbol();
9431 }
9432 
9460  Symbol indices) {
9461  return Operator("gather_nd")
9462  .SetInput("data", data)
9463  .SetInput("indices", indices)
9464  .CreateSymbol();
9465 }
9466 
9503  Symbol indices,
9504  Shape shape) {
9505  return Operator("scatter_nd")
9506  .SetParam("shape", shape)
9507  .SetInput("data", data)
9508  .SetInput("indices", indices)
9509  .CreateSymbol();
9510 }
9511 
9534  Symbol rhs) {
9535  return Operator("broadcast_equal")
9536  .SetInput("lhs", lhs)
9537  .SetInput("rhs", rhs)
9538  .CreateSymbol();
9539 }
9540 
9563  Symbol rhs) {
9564  return Operator("broadcast_not_equal")
9565  .SetInput("lhs", lhs)
9566  .SetInput("rhs", rhs)
9567  .CreateSymbol();
9568 }
9569 
9592  Symbol rhs) {
9593  return Operator("broadcast_greater")
9594  .SetInput("lhs", lhs)
9595  .SetInput("rhs", rhs)
9596  .CreateSymbol();
9597 }
9598 
9621  Symbol rhs) {
9622  return Operator("broadcast_greater_equal")
9623  .SetInput("lhs", lhs)
9624  .SetInput("rhs", rhs)
9625  .CreateSymbol();
9626 }
9627 
9650  Symbol rhs) {
9651  return Operator("broadcast_lesser")
9652  .SetInput("lhs", lhs)
9653  .SetInput("rhs", rhs)
9654  .CreateSymbol();
9655 }
9656 
9679  Symbol rhs) {
9680  return Operator("broadcast_lesser_equal")
9681  .SetInput("lhs", lhs)
9682  .SetInput("rhs", rhs)
9683  .CreateSymbol();
9684 }
9685 
9701 inline Symbol where(Symbol condition,
9702  Symbol x,
9703  Symbol y) {
9704  return Operator("where")
9705  .SetInput("condition", condition)
9706  .SetInput("x", x)
9707  .SetInput("y", y)
9708  .CreateSymbol();
9709 }
9710 
9736  mx_float scalar) {
9737  return Operator("smooth_l1")
9738  .SetParam("scalar", scalar)
9739  .SetInput("data", data)
9740  .CreateSymbol();
9741 }
9742 
9786  Cast_storageStype stype) {
9787  static const char *Cast_storageStypeValues[] = {
9788  "csr",
9789  "default",
9790  "row_sparse"
9791  };
9792  return Operator("cast_storage")
9793  .SetParam("stype", Cast_storageStypeValues[int(stype)])
9794  .SetInput("data", data)
9795  .CreateSymbol();
9796 }
9797 
9817 inline Symbol sin(Symbol data) {
9818  return Operator("sin")
9819  .SetInput("data", data)
9820  .CreateSymbol();
9821 }
9822 
9839 inline Symbol cos(Symbol data) {
9840  return Operator("cos")
9841  .SetInput("data", data)
9842  .CreateSymbol();
9843 }
9844 
9864 inline Symbol tan(Symbol data) {
9865  return Operator("tan")
9866  .SetInput("data", data)
9867  .CreateSymbol();
9868 }
9869 
9890 inline Symbol arcsin(Symbol data) {
9891  return Operator("arcsin")
9892  .SetInput("data", data)
9893  .CreateSymbol();
9894 }
9895 
9913 inline Symbol arccos(Symbol data) {
9914  return Operator("arccos")
9915  .SetInput("data", data)
9916  .CreateSymbol();
9917 }
9918 
9938 inline Symbol arctan(Symbol data) {
9939  return Operator("arctan")
9940  .SetInput("data", data)
9941  .CreateSymbol();
9942 }
9943 
9961 inline Symbol degrees(Symbol data) {
9962  return Operator("degrees")
9963  .SetInput("data", data)
9964  .CreateSymbol();
9965 }
9966 
9984 inline Symbol radians(Symbol data) {
9985  return Operator("radians")
9986  .SetInput("data", data)
9987  .CreateSymbol();
9988 }
9989 
10007 inline Symbol sinh(Symbol data) {
10008  return Operator("sinh")
10009  .SetInput("data", data)
10010  .CreateSymbol();
10011 }
10012 
10027 inline Symbol cosh(Symbol data) {
10028  return Operator("cosh")
10029  .SetInput("data", data)
10030  .CreateSymbol();
10031 }
10032 
10050 inline Symbol tanh(Symbol data) {
10051  return Operator("tanh")
10052  .SetInput("data", data)
10053  .CreateSymbol();
10054 }
10055 
10071 inline Symbol arcsinh(Symbol data) {
10072  return Operator("arcsinh")
10073  .SetInput("data", data)
10074  .CreateSymbol();
10075 }
10076 
10089 inline Symbol arccosh(Symbol data) {
10090  return Operator("arccosh")
10091  .SetInput("data", data)
10092  .CreateSymbol();
10093 }
10094 
10110 inline Symbol arctanh(Symbol data) {
10111  return Operator("arctanh")
10112  .SetInput("data", data)
10113  .CreateSymbol();
10114 }
10115 
10130 inline Symbol Custom(const std::vector<Symbol>& data,
10131  const std::string& op_type) {
10132  return Operator("Custom")
10133 (data)
10134  .CreateSymbol();
10135 }
10136 
10165 inline Symbol softmax(Symbol data,
10166  int axis = -1) {
10167  return Operator("softmax")
10168  .SetParam("axis", axis)
10169  .SetInput("data", data)
10170  .CreateSymbol();
10171 }
10172 
10195  int axis = -1) {
10196  return Operator("log_softmax")
10197  .SetParam("axis", axis)
10198  .SetInput("data", data)
10199  .CreateSymbol();
10200 }
10201 
10229  mx_float slope = 0.25,
10230  mx_float lower_bound = 0.125,
10231  mx_float upper_bound = 0.334) {
10232  static const char *LeakyReLUActTypeValues[] = {
10233  "elu",
10234  "leaky",
10235  "prelu",
10236  "rrelu"
10237  };
10238  return Operator("LeakyReLU")
10239  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
10240  .SetParam("slope", slope)
10241  .SetParam("lower_bound", lower_bound)
10242  .SetParam("upper_bound", upper_bound)
10243  .SetInput("data", data)
10244  .CreateSymbol();
10245 }
10246 
10274 inline Symbol SwapAxis(Symbol data,
10275  uint32_t dim1 = 0,
10276  uint32_t dim2 = 0) {
10277  return Operator("SwapAxis")
10278  .SetParam("dim1", dim1)
10279  .SetParam("dim2", dim2)
10280  .SetInput("data", data)
10281  .CreateSymbol();
10282 }
10283 
10339  Symbol gamma,
10340  Symbol beta,
10341  mx_float eps = 0.001,
10342  mx_float momentum = 0.9,
10343  bool fix_gamma = true,
10344  bool use_global_stats = false,
10345  bool output_mean_var = false) {
10346  return Operator("BatchNorm_v1")
10347  .SetParam("eps", eps)
10348  .SetParam("momentum", momentum)
10349  .SetParam("fix_gamma", fix_gamma)
10350  .SetParam("use_global_stats", use_global_stats)
10351  .SetParam("output_mean_var", output_mean_var)
10352  .SetInput("data", data)
10353  .SetInput("gamma", gamma)
10354  .SetInput("beta", beta)
10355  .CreateSymbol();
10356 }
10357 
10398 inline Symbol Concat(const std::vector<Symbol>& data,
10399  int num_args,
10400  int dim = 1) {
10401  return Operator("Concat")
10402  .SetParam("num_args", num_args)
10403  .SetParam("dim", dim)
10404 (data)
10405  .CreateSymbol();
10406 }
10407 
10434 inline Symbol sgd_update(Symbol weight,
10435  Symbol grad,
10436  mx_float lr,
10437  mx_float wd = 0,
10438  mx_float rescale_grad = 1,
10439  mx_float clip_gradient = -1) {
10440  return Operator("sgd_update")
10441  .SetParam("lr", lr)
10442  .SetParam("wd", wd)
10443  .SetParam("rescale_grad", rescale_grad)
10444  .SetParam("clip_gradient", clip_gradient)
10445  .SetInput("weight", weight)
10446  .SetInput("grad", grad)
10447  .CreateSymbol();
10448 }
10449 
10492  Symbol grad,
10493  Symbol mom,
10494  mx_float lr,
10495  mx_float momentum = 0,
10496  mx_float wd = 0,
10497  mx_float rescale_grad = 1,
10498  mx_float clip_gradient = -1) {
10499  return Operator("sgd_mom_update")
10500  .SetParam("lr", lr)
10501  .SetParam("momentum", momentum)
10502  .SetParam("wd", wd)
10503  .SetParam("rescale_grad", rescale_grad)
10504  .SetParam("clip_gradient", clip_gradient)
10505  .SetInput("weight", weight)
10506  .SetInput("grad", grad)
10507  .SetInput("mom", mom)
10508  .CreateSymbol();
10509 }
10510 
10525  Symbol grad,
10526  Symbol weight32,
10527  mx_float lr,
10528  mx_float wd = 0,
10529  mx_float rescale_grad = 1,
10530  mx_float clip_gradient = -1) {
10531  return Operator("mp_sgd_update")
10532  .SetParam("lr", lr)
10533  .SetParam("wd", wd)
10534  .SetParam("rescale_grad", rescale_grad)
10535  .SetParam("clip_gradient", clip_gradient)
10536  .SetInput("weight", weight)
10537  .SetInput("grad", grad)
10538  .SetInput("weight32", weight32)
10539  .CreateSymbol();
10540 }
10541 
10558  Symbol grad,
10559  Symbol mom,
10560  Symbol weight32,
10561  mx_float lr,
10562  mx_float momentum = 0,
10563  mx_float wd = 0,
10564  mx_float rescale_grad = 1,
10565  mx_float clip_gradient = -1) {
10566  return Operator("mp_sgd_mom_update")
10567  .SetParam("lr", lr)
10568  .SetParam("momentum", momentum)
10569  .SetParam("wd", wd)
10570  .SetParam("rescale_grad", rescale_grad)
10571  .SetParam("clip_gradient", clip_gradient)
10572  .SetInput("weight", weight)
10573  .SetInput("grad", grad)
10574  .SetInput("mom", mom)
10575  .SetInput("weight32", weight32)
10576  .CreateSymbol();
10577 }
10578 
10625 inline Symbol adam_update(Symbol weight,
10626  Symbol grad,
10627  Symbol mean,
10628  Symbol var,
10629  mx_float lr,
10630  mx_float beta1 = 0.9,
10631  mx_float beta2 = 0.999,
10632  mx_float epsilon = 1e-08,
10633  mx_float wd = 0,
10634  mx_float rescale_grad = 1,
10635  mx_float clip_gradient = -1) {
10636  return Operator("adam_update")
10637  .SetParam("lr", lr)
10638  .SetParam("beta1", beta1)
10639  .SetParam("beta2", beta2)
10640  .SetParam("epsilon", epsilon)
10641  .SetParam("wd", wd)
10642  .SetParam("rescale_grad", rescale_grad)
10643  .SetParam("clip_gradient", clip_gradient)
10644  .SetInput("weight", weight)
10645  .SetInput("grad", grad)
10646  .SetInput("mean", mean)
10647  .SetInput("var", var)
10648  .CreateSymbol();
10649 }
10650 
10704  Symbol grad,
10705  Symbol n,
10706  mx_float lr,
10707  mx_float gamma1 = 0.95,
10708  mx_float epsilon = 1e-08,
10709  mx_float wd = 0,
10710  mx_float rescale_grad = 1,
10711  mx_float clip_gradient = -1,
10712  mx_float clip_weights = -1) {
10713  return Operator("rmsprop_update")
10714  .SetParam("lr", lr)
10715  .SetParam("gamma1", gamma1)
10716  .SetParam("epsilon", epsilon)
10717  .SetParam("wd", wd)
10718  .SetParam("rescale_grad", rescale_grad)
10719  .SetParam("clip_gradient", clip_gradient)
10720  .SetParam("clip_weights", clip_weights)
10721  .SetInput("weight", weight)
10722  .SetInput("grad", grad)
10723  .SetInput("n", n)
10724  .CreateSymbol();
10725 }
10726 
10772  Symbol grad,
10773  Symbol n,
10774  Symbol g,
10775  Symbol delta,
10776  mx_float lr,
10777  mx_float gamma1 = 0.95,
10778  mx_float gamma2 = 0.9,
10779  mx_float epsilon = 1e-08,
10780  mx_float wd = 0,
10781  mx_float rescale_grad = 1,
10782  mx_float clip_gradient = -1,
10783  mx_float clip_weights = -1) {
10784  return Operator("rmspropalex_update")
10785  .SetParam("lr", lr)
10786  .SetParam("gamma1", gamma1)
10787  .SetParam("gamma2", gamma2)
10788  .SetParam("epsilon", epsilon)
10789  .SetParam("wd", wd)
10790  .SetParam("rescale_grad", rescale_grad)
10791  .SetParam("clip_gradient", clip_gradient)
10792  .SetParam("clip_weights", clip_weights)
10793  .SetInput("weight", weight)
10794  .SetInput("grad", grad)
10795  .SetInput("n", n)
10796  .SetInput("g", g)
10797  .SetInput("delta", delta)
10798  .CreateSymbol();
10799 }
10800 
10839 inline Symbol ftrl_update(Symbol weight,
10840  Symbol grad,
10841  Symbol z,
10842  Symbol n,
10843  mx_float lr,
10844  mx_float lamda1 = 0.01,
10845  mx_float beta = 1,
10846  mx_float wd = 0,
10847  mx_float rescale_grad = 1,
10848  mx_float clip_gradient = -1) {
10849  return Operator("ftrl_update")
10850  .SetParam("lr", lr)
10851  .SetParam("lamda1", lamda1)
10852  .SetParam("beta", beta)
10853  .SetParam("wd", wd)
10854  .SetParam("rescale_grad", rescale_grad)
10855  .SetParam("clip_gradient", clip_gradient)
10856  .SetInput("weight", weight)
10857  .SetInput("grad", grad)
10858  .SetInput("z", z)
10859  .SetInput("n", n)
10860  .CreateSymbol();
10861 }
10862 
10958 inline Symbol Pad(Symbol data,
10959  PadMode mode,
10960  Shape pad_width,
10961  double constant_value = 0) {
10962  static const char *PadModeValues[] = {
10963  "constant",
10964  "edge",
10965  "reflect"
10966  };
10967  return Operator("Pad")
10968  .SetParam("mode", PadModeValues[int(mode)])
10969  .SetParam("pad_width", pad_width)
10970  .SetParam("constant_value", constant_value)
10971  .SetInput("data", data)
10972  .CreateSymbol();
10973 }
10974 
10984  mx_float sparseness_target = 0.1,
10985  mx_float penalty = 0.001,
10986  mx_float momentum = 0.9) {
10987  return Operator("IdentityAttachKLSparseReg")
10988  .SetParam("sparseness_target", sparseness_target)
10989  .SetParam("penalty", penalty)
10990  .SetParam("momentum", momentum)
10991  .SetInput("data", data)
10992  .CreateSymbol();
10993 }
10994 
11066  int num_outputs,
11067  int axis = 1,
11068  bool squeeze_axis = false) {
11069  return Operator("SliceChannel")
11070  .SetParam("num_outputs", num_outputs)
11071  .SetParam("axis", axis)
11072  .SetParam("squeeze_axis", squeeze_axis)
11073  .SetInput("data", data)
11074  .CreateSymbol();
11075 }
11076 
11114  Symbol label) {
11115  return Operator("softmax_cross_entropy")
11116  .SetInput("data", data)
11117  .SetInput("label", label)
11118  .CreateSymbol();
11119 }
11120 
11135 inline Symbol UpSampling(const std::vector<Symbol>& data,
11136  uint32_t scale,
11137  UpSamplingSampleType sample_type,
11138  int num_args,
11139  uint32_t num_filter = 0,
11141  uint64_t workspace = 512) {
11142  static const char *UpSamplingSampleTypeValues[] = {
11143  "bilinear",
11144  "nearest"
11145  };
11146  static const char *UpSamplingMultiInputModeValues[] = {
11147  "concat",
11148  "sum"
11149  };
11150  return Operator("UpSampling")
11151  .SetParam("scale", scale)
11152  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
11153  .SetParam("num_args", num_args)
11154  .SetParam("num_filter", num_filter)
11155  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
11156  .SetParam("workspace", workspace)
11157 (data)
11158  .CreateSymbol();
11159 }
11160 
11224  Symbol gamma,
11225  Symbol beta,
11226  Symbol moving_mean,
11227  Symbol moving_var,
11228  double eps = 0.001,
11229  mx_float momentum = 0.9,
11230  bool fix_gamma = true,
11231  bool use_global_stats = false,
11232  bool output_mean_var = false,
11233  int axis = 1,
11234  bool cudnn_off = false) {
11235  return Operator("BatchNorm")
11236  .SetParam("eps", eps)
11237  .SetParam("momentum", momentum)
11238  .SetParam("fix_gamma", fix_gamma)
11239  .SetParam("use_global_stats", use_global_stats)
11240  .SetParam("output_mean_var", output_mean_var)
11241  .SetParam("axis", axis)
11242  .SetParam("cudnn_off", cudnn_off)
11243  .SetInput("data", data)
11244  .SetInput("gamma", gamma)
11245  .SetInput("beta", beta)
11246  .SetInput("moving_mean", moving_mean)
11247  .SetInput("moving_var", moving_var)
11248  .CreateSymbol();
11249 }
11250 
11301  Symbol gamma,
11302  Symbol beta,
11303  mx_float eps = 0.001) {
11304  return Operator("InstanceNorm")
11305  .SetParam("eps", eps)
11306  .SetInput("data", data)
11307  .SetInput("gamma", gamma)
11308  .SetInput("beta", beta)
11309  .CreateSymbol();
11310 }
11311 
11326 inline Symbol RNN(Symbol data,
11327  Symbol parameters,
11328  Symbol state,
11329  Symbol state_cell,
11330  uint32_t state_size,
11331  uint32_t num_layers,
11332  RNNMode mode,
11333  bool bidirectional = false,
11334  mx_float p = 0,
11335  bool state_outputs = false) {
11336  static const char *RNNModeValues[] = {
11337  "gru",
11338  "lstm",
11339  "rnn_relu",
11340  "rnn_tanh"
11341  };
11342  return Operator("RNN")
11343  .SetParam("state_size", state_size)
11344  .SetParam("num_layers", num_layers)
11345  .SetParam("mode", RNNModeValues[int(mode)])
11346  .SetParam("bidirectional", bidirectional)
11347  .SetParam("p", p)
11348  .SetParam("state_outputs", state_outputs)
11349  .SetInput("data", data)
11350  .SetInput("parameters", parameters)
11351  .SetInput("state", state)
11352  .SetInput("state_cell", state_cell)
11353  .CreateSymbol();
11354 }
11355 
11384  Symbol weight,
11385  Symbol bias,
11386  Shape kernel,
11387  uint32_t num_filter,
11388  Shape stride = Shape(),
11389  Shape dilate = Shape(),
11390  Shape pad = Shape(),
11391  uint32_t num_group = 1,
11392  uint64_t workspace = 1024,
11393  bool no_bias = false,
11395  bool cudnn_off = false,
11397  static const char *Convolution_v1CudnnTuneValues[] = {
11398  "None",
11399  "fastest",
11400  "limited_workspace",
11401  "off"
11402  };
11403  static const char *Convolution_v1LayoutValues[] = {
11404  "None",
11405  "NCDHW",
11406  "NCHW",
11407  "NDHWC",
11408  "NHWC"
11409  };
11410  return Operator("Convolution_v1")
11411  .SetParam("kernel", kernel)
11412  .SetParam("num_filter", num_filter)
11413  .SetParam("stride", stride)
11414  .SetParam("dilate", dilate)
11415  .SetParam("pad", pad)
11416  .SetParam("num_group", num_group)
11417  .SetParam("workspace", workspace)
11418  .SetParam("no_bias", no_bias)
11419  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
11420  .SetParam("cudnn_off", cudnn_off)
11421  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
11422  .SetInput("data", data)
11423  .SetInput("weight", weight)
11424  .SetInput("bias", bias)
11425  .CreateSymbol();
11426 }
11427 
11447 inline Symbol Crop(const std::vector<Symbol>& data,
11448  int num_args,
11449  Shape offset = Shape(0,0),
11450  Shape h_w = Shape(0,0),
11451  bool center_crop = false) {
11452  return Operator("Crop")
11453  .SetParam("num_args", num_args)
11454  .SetParam("offset", offset)
11455  .SetParam("h_w", h_w)
11456  .SetParam("center_crop", center_crop)
11457 (data)
11458  .CreateSymbol();
11459 }
11460 
11471  Symbol loc,
11472  SpatialTransformerTransformType transform_type,
11473  SpatialTransformerSamplerType sampler_type,
11474  Shape target_shape = Shape(0,0)) {
11475  static const char *SpatialTransformerTransformTypeValues[] = {
11476  "affine"
11477  };
11478  static const char *SpatialTransformerSamplerTypeValues[] = {
11479  "bilinear"
11480  };
11481  return Operator("SpatialTransformer")
11482  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
11483  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
11484  .SetParam("target_shape", target_shape)
11485  .SetInput("data", data)
11486  .SetInput("loc", loc)
11487  .CreateSymbol();
11488 }
11489 
11516  Symbol weight,
11517  Symbol bias,
11518  Shape kernel,
11519  uint32_t num_filter,
11520  Shape stride = Shape(),
11521  Shape dilate = Shape(),
11522  Shape pad = Shape(),
11523  Shape adj = Shape(),
11524  Shape target_shape = Shape(),
11525  uint32_t num_group = 1,
11526  uint64_t workspace = 512,
11527  bool no_bias = true,
11529  bool cudnn_off = false,
11531  static const char *DeconvolutionCudnnTuneValues[] = {
11532  "None",
11533  "fastest",
11534  "limited_workspace",
11535  "off"
11536  };
11537  static const char *DeconvolutionLayoutValues[] = {
11538  "None",
11539  "NCDHW",
11540  "NCHW",
11541  "NCW",
11542  "NDHWC",
11543  "NHWC"
11544  };
11545  return Operator("Deconvolution")
11546  .SetParam("kernel", kernel)
11547  .SetParam("num_filter", num_filter)
11548  .SetParam("stride", stride)
11549  .SetParam("dilate", dilate)
11550  .SetParam("pad", pad)
11551  .SetParam("adj", adj)
11552  .SetParam("target_shape", target_shape)
11553  .SetParam("num_group", num_group)
11554  .SetParam("workspace", workspace)
11555  .SetParam("no_bias", no_bias)
11556  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
11557  .SetParam("cudnn_off", cudnn_off)
11558  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
11559  .SetInput("data", data)
11560  .SetInput("weight", weight)
11561  .SetInput("bias", bias)
11562  .CreateSymbol();
11563 }
11564 
11660  Symbol label,
11661  mx_float grad_scale = 1,
11662  mx_float ignore_label = -1,
11663  bool multi_output = false,
11664  bool use_ignore = false,
11665  bool preserve_shape = false,
11667  bool out_grad = false,
11668  mx_float smooth_alpha = 0) {
11669  static const char *SoftmaxOutputNormalizationValues[] = {
11670  "batch",
11671  "null",
11672  "valid"
11673  };
11674  return Operator("SoftmaxOutput")
11675  .SetParam("grad_scale", grad_scale)
11676  .SetParam("ignore_label", ignore_label)
11677  .SetParam("multi_output", multi_output)
11678  .SetParam("use_ignore", use_ignore)
11679  .SetParam("preserve_shape", preserve_shape)
11680  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
11681  .SetParam("out_grad", out_grad)
11682  .SetParam("smooth_alpha", smooth_alpha)
11683  .SetInput("data", data)
11684  .SetInput("label", label)
11685  .CreateSymbol();
11686 }
11687 
11714 inline Symbol Softmax(Symbol data,
11715  mx_float grad_scale = 1,
11716  mx_float ignore_label = -1,
11717  bool multi_output = false,
11718  bool use_ignore = false,
11719  bool preserve_shape = false,
11721  bool out_grad = false,
11722  mx_float smooth_alpha = 0) {
11723  static const char *SoftmaxNormalizationValues[] = {
11724  "batch",
11725  "null",
11726  "valid"
11727  };
11728  return Operator("Softmax")
11729  .SetParam("grad_scale", grad_scale)
11730  .SetParam("ignore_label", ignore_label)
11731  .SetParam("multi_output", multi_output)
11732  .SetParam("use_ignore", use_ignore)
11733  .SetParam("preserve_shape", preserve_shape)
11734  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
11735  .SetParam("out_grad", out_grad)
11736  .SetParam("smooth_alpha", smooth_alpha)
11737  .SetInput("data", data)
11738  .CreateSymbol();
11739 }
11740 
11816  Symbol sequence_length,
11817  bool use_sequence_length = false) {
11818  return Operator("SequenceReverse")
11819  .SetParam("use_sequence_length", use_sequence_length)
11820  .SetInput("data", data)
11821  .SetInput("sequence_length", sequence_length)
11822  .CreateSymbol();
11823 }
11824 
11879  Symbol sequence_length,
11880  bool use_sequence_length = false) {
11881  return Operator("SequenceLast")
11882  .SetParam("use_sequence_length", use_sequence_length)
11883  .SetInput("data", data)
11884  .SetInput("sequence_length", sequence_length)
11885  .CreateSymbol();
11886 }
11887 
11936  Symbol data2,
11937  uint32_t kernel_size = 1,
11938  uint32_t max_displacement = 1,
11939  uint32_t stride1 = 1,
11940  uint32_t stride2 = 1,
11941  uint32_t pad_size = 0,
11942  bool is_multiply = true) {
11943  return Operator("Correlation")
11944  .SetParam("kernel_size", kernel_size)
11945  .SetParam("max_displacement", max_displacement)
11946  .SetParam("stride1", stride1)
11947  .SetParam("stride2", stride2)
11948  .SetParam("pad_size", pad_size)
11949  .SetParam("is_multiply", is_multiply)
11950  .SetInput("data1", data1)
11951  .SetInput("data2", data2)
11952  .CreateSymbol();
11953 }
11954 
11970  Symbol label,
11971  mx_float margin = 1,
11972  mx_float regularization_coefficient = 1,
11973  bool use_linear = false) {
11974  return Operator("SVMOutput")
11975  .SetParam("margin", margin)
11976  .SetParam("regularization_coefficient", regularization_coefficient)
11977  .SetParam("use_linear", use_linear)
11978  .SetInput("data", data)
11979  .SetInput("label", label)
11980  .CreateSymbol();
11981 }
11982 
12045  mx_float eps = 1e-10,
12047  static const char *L2NormalizationModeValues[] = {
12048  "channel",
12049  "instance",
12050  "spatial"
12051  };
12052  return Operator("L2Normalization")
12053  .SetParam("eps", eps)
12054  .SetParam("mode", L2NormalizationModeValues[int(mode)])
12055  .SetInput("data", data)
12056  .CreateSymbol();
12057 }
12058 
12085 inline Symbol LRN(Symbol data,
12086  uint32_t nsize,
12087  mx_float alpha = 0.0001,
12088  mx_float beta = 0.75,
12089  mx_float knorm = 2) {
12090  return Operator("LRN")
12091  .SetParam("nsize", nsize)
12092  .SetParam("alpha", alpha)
12093  .SetParam("beta", beta)
12094  .SetParam("knorm", knorm)
12095  .SetInput("data", data)
12096  .CreateSymbol();
12097 }
12098 
12132  Symbol weight,
12133  Symbol bias,
12134  int num_hidden,
12135  bool no_bias = false,
12136  bool flatten = true) {
12137  return Operator("FullyConnected")
12138  .SetParam("num_hidden", num_hidden)
12139  .SetParam("no_bias", no_bias)
12140  .SetParam("flatten", flatten)
12141  .SetInput("data", data)
12142  .SetInput("weight", weight)
12143  .SetInput("bias", bias)
12144  .CreateSymbol();
12145 }
12146 
12224  Symbol sequence_length,
12225  bool use_sequence_length = false,
12226  mx_float value = 0) {
12227  return Operator("SequenceMask")
12228  .SetParam("use_sequence_length", use_sequence_length)
12229  .SetParam("value", value)
12230  .SetInput("data", data)
12231  .SetInput("sequence_length", sequence_length)
12232  .CreateSymbol();
12233 }
12234 
12245  GridGeneratorTransformType transform_type,
12246  Shape target_shape = Shape(0,0)) {
12247  static const char *GridGeneratorTransformTypeValues[] = {
12248  "affine",
12249  "warp"
12250  };
12251  return Operator("GridGenerator")
12252  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
12253  .SetParam("target_shape", target_shape)
12254  .SetInput("data", data)
12255  .CreateSymbol();
12256 }
12257 
12309  Shape kernel,
12310  Pooling_v1PoolType pool_type,
12311  bool global_pool = false,
12313  Shape stride = Shape(),
12314  Shape pad = Shape()) {
12315  static const char *Pooling_v1PoolTypeValues[] = {
12316  "avg",
12317  "max",
12318  "sum"
12319  };
12320  static const char *Pooling_v1PoolingConventionValues[] = {
12321  "full",
12322  "valid"
12323  };
12324  return Operator("Pooling_v1")
12325  .SetParam("kernel", kernel)
12326  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
12327  .SetParam("global_pool", global_pool)
12328  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
12329  .SetParam("stride", stride)
12330  .SetParam("pad", pad)
12331  .SetInput("data", data)
12332  .CreateSymbol();
12333 }
12334 
12429  Symbol weight,
12430  Symbol bias,
12431  Shape kernel,
12432  uint32_t num_filter,
12433  Shape stride = Shape(),
12434  Shape dilate = Shape(),
12435  Shape pad = Shape(),
12436  uint32_t num_group = 1,
12437  uint64_t workspace = 1024,
12438  bool no_bias = false,
12440  bool cudnn_off = false,
12442  static const char *ConvolutionCudnnTuneValues[] = {
12443  "None",
12444  "fastest",
12445  "limited_workspace",
12446  "off"
12447  };
12448  static const char *ConvolutionLayoutValues[] = {
12449  "None",
12450  "NCDHW",
12451  "NCHW",
12452  "NCW",
12453  "NDHWC",
12454  "NHWC"
12455  };
12456  return Operator("Convolution")
12457  .SetParam("kernel", kernel)
12458  .SetParam("num_filter", num_filter)
12459  .SetParam("stride", stride)
12460  .SetParam("dilate", dilate)
12461  .SetParam("pad", pad)
12462  .SetParam("num_group", num_group)
12463  .SetParam("workspace", workspace)
12464  .SetParam("no_bias", no_bias)
12465  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
12466  .SetParam("cudnn_off", cudnn_off)
12467  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
12468  .SetInput("data", data)
12469  .SetInput("weight", weight)
12470  .SetInput("bias", bias)
12471  .CreateSymbol();
12472 }
12473 
12554  Symbol grid) {
12555  return Operator("BilinearSampler")
12556  .SetInput("data", data)
12557  .SetInput("grid", grid)
12558  .CreateSymbol();
12559 }
12560 
12613 inline Symbol Pooling(Symbol data,
12614  Shape kernel,
12615  PoolingPoolType pool_type,
12616  bool global_pool = false,
12617  bool cudnn_off = false,
12619  Shape stride = Shape(),
12620  Shape pad = Shape()) {
12621  static const char *PoolingPoolTypeValues[] = {
12622  "avg",
12623  "max",
12624  "sum"
12625  };
12626  static const char *PoolingPoolingConventionValues[] = {
12627  "full",
12628  "valid"
12629  };
12630  return Operator("Pooling")
12631  .SetParam("kernel", kernel)
12632  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
12633  .SetParam("global_pool", global_pool)
12634  .SetParam("cudnn_off", cudnn_off)
12635  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
12636  .SetParam("stride", stride)
12637  .SetParam("pad", pad)
12638  .SetInput("data", data)
12639  .CreateSymbol();
12640 }
12641 
12680 inline Symbol Dropout(Symbol data,
12681  mx_float p = 0.5,
12683  static const char *DropoutModeValues[] = {
12684  "always",
12685  "training"
12686  };
12687  return Operator("Dropout")
12688  .SetParam("p", p)
12689  .SetParam("mode", DropoutModeValues[int(mode)])
12690  .SetInput("data", data)
12691  .CreateSymbol();
12692 }
12693 
12712  ActivationActType act_type) {
12713  static const char *ActivationActTypeValues[] = {
12714  "relu",
12715  "sigmoid",
12716  "softrelu",
12717  "tanh"
12718  };
12719  return Operator("Activation")
12720  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
12721  .SetInput("data", data)
12722  .CreateSymbol();
12723 }
12724 
12781  Symbol rois,
12782  Shape pooled_size,
12783  mx_float spatial_scale) {
12784  return Operator("ROIPooling")
12785  .SetParam("pooled_size", pooled_size)
12786  .SetParam("spatial_scale", spatial_scale)
12787  .SetInput("data", data)
12788  .SetInput("rois", rois)
12789  .CreateSymbol();
12790 }
12791 
12816  Symbol label,
12817  mx_float grad_scale = 1) {
12818  return Operator("LinearRegressionOutput")
12819  .SetParam("grad_scale", grad_scale)
12820  .SetInput("data", data)
12821  .SetInput("label", label)
12822  .CreateSymbol();
12823 }
12824 
12850  Symbol label,
12851  mx_float grad_scale = 1) {
12852  return Operator("MAERegressionOutput")
12853  .SetParam("grad_scale", grad_scale)
12854  .SetInput("data", data)
12855  .SetInput("label", label)
12856  .CreateSymbol();
12857 }
12858 
12884  Symbol label,
12885  mx_float grad_scale = 1) {
12886  return Operator("LogisticRegressionOutput")
12887  .SetParam("grad_scale", grad_scale)
12888  .SetInput("data", data)
12889  .SetInput("label", label)
12890  .CreateSymbol();
12891 }
12892 
12927  static const char *SoftmaxActivationModeValues[] = {
12928  "channel",
12929  "instance"
12930  };
12931  return Operator("SoftmaxActivation")
12932  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
12933  .SetInput("data", data)
12934  .CreateSymbol();
12935 }
12936 
12970 inline Symbol MakeLoss(Symbol data,
12971  mx_float grad_scale = 1,
12972  mx_float valid_thresh = 0,
12974  static const char *MakeLossNormalizationValues[] = {
12975  "batch",
12976  "null",
12977  "valid"
12978  };
12979  return Operator("MakeLoss")
12980  .SetParam("grad_scale", grad_scale)
12981  .SetParam("valid_thresh", valid_thresh)
12982  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
12983  .SetInput("data", data)
12984  .CreateSymbol();
12985 }
12986 
12995  Symbol rhs) {
12996  return Operator("choose_element_0index")
12997  .SetInput("lhs", lhs)
12998  .SetInput("rhs", rhs)
12999  .CreateSymbol();
13000 }
13001 
13011  Symbol mhs,
13012  Symbol rhs) {
13013  return Operator("fill_element_0index")
13014  .SetInput("lhs", lhs)
13015  .SetInput("mhs", mhs)
13016  .SetInput("rhs", rhs)
13017  .CreateSymbol();
13018 }
13019 
13020 } //namespace cpp
13021 } //namespace mxnet
13022 #endif // MXNET_CPP_OP_H_
Symbol Convolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, ConvolutionCudnnTune cudnn_tune=ConvolutionCudnnTune::kNone, bool cudnn_off=false, ConvolutionLayout layout=ConvolutionLayout::kNone)
Definition: op.h:6150
Symbol fix(const std::string &symbol_name, Symbol data)
Definition: op.h:1621
Symbol Crop(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, Shape offset=Shape(0, 0), Shape h_w=Shape(0, 0), bool center_crop=false)
Definition: op.h:5039
Symbol min(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2219
Symbol broadcast_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:824
Symbol arcsin(const std::string &symbol_name, Symbol data)
Definition: op.h:3343
Symbol FullyConnected(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, int num_hidden, bool no_bias=false, bool flatten=true)
Definition: op.h:5801
Symbol arccosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3560
Symbol arctan(const std::string &symbol_name, Symbol data)
Definition: op.h:3395
Symbol SwapAxis(const std::string &symbol_name, Symbol data, uint32_t dim1=0, uint32_t dim2=0)
Definition: op.h:3766
Symbol cast_storage(const std::string &symbol_name, Symbol data, Cast_storageStype stype)
Definition: op.h:3230
Symbol nansum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2106
Symbol add_n(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:915
Symbol log1p(const std::string &symbol_name, Symbol data)
Definition: op.h:1868
SoftmaxActivationMode
Definition: op.h:6666
Symbol mp_sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol weight32, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4026
Symbol SpatialTransformer(const std::string &symbol_name, Symbol data, Symbol loc, SpatialTransformerTransformType transform_type, SpatialTransformerSamplerType sampler_type, Shape target_shape=Shape(0, 0))
Definition: op.h:5076
Symbol slice(const std::string &symbol_name, Symbol data, Shape begin, Shape end, Shape step=Shape())
Definition: op.h:408
Symbol exp(const std::string &symbol_name, Symbol data)
Definition: op.h:1780
Symbol transpose(const std::string &symbol_name, Symbol data, Shape axes=Shape())
Definition: op.h:320
Symbol clip(const std::string &symbol_name, Symbol data, mx_float a_min, mx_float a_max)
Definition: op.h:500
Symbol elemwise_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2577
Symbol ROIPooling(const std::string &symbol_name, Symbol data, Symbol rois, Shape pooled_size, mx_float spatial_scale)
Definition: op.h:6543
Symbol broadcast_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:855
Symbol nanprod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2145
Convolution_v1Layout
Definition: op.h:4937
Symbol argmin(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:995
Symbol mp_sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, Symbol weight32, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4061
Symbol broadcast_lesser(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3078
Symbol fill_element_0index(const std::string &symbol_name, Symbol lhs, Symbol mhs, Symbol rhs)
Definition: op.h:6805
Symbol Convolution_v1(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, Convolution_v1CudnnTune cudnn_tune=Convolution_v1CudnnTune::kNone, bool cudnn_off=false, Convolution_v1Layout layout=Convolution_v1Layout::kNone)
Definition: op.h:4973
Symbol broadcast_not_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2985
TakeMode
Definition: op.h:2673
Symbol Embedding(const std::string &symbol_name, Symbol data, Symbol weight, int input_dim, int output_dim, EmbeddingDtype dtype=EmbeddingDtype::kFloat32)
Definition: op.h:2647
Symbol ftrl_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol z, Symbol n, mx_float lr, mx_float lamda1=0.01, mx_float beta=1, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4351
Symbol reciprocal(const std::string &symbol_name, Symbol data)
Definition: op.h:1401
TopkRetTyp
Definition: op.h:2338
Symbol RNN(const std::string &symbol_name, Symbol data, Symbol parameters, Symbol state, Symbol state_cell, uint32_t state_size, uint32_t num_layers, RNNMode mode, bool bidirectional=false, mx_float p=0, bool state_outputs=false)
Definition: op.h:4887
namespace of mxnet
Definition: base.h:127
Symbol reshape_like(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1310
Pooling_v1PoolingConvention
Definition: op.h:5951
Symbol broadcast_lesser_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3109
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
Symbol InstanceNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001)
Definition: op.h:4850
Symbol sign(const std::string &symbol_name, Symbol data)
Definition: op.h:1453
GridGeneratorTransformType
Definition: op.h:5911
Cast_storageStype
Definition: op.h:3181
dynamic shape class that can hold shape of arbirary dimension
Definition: shape.h:43
Symbol ones_like(const std::string &symbol_name, Symbol data)
Definition: op.h:723
RNNMode
Definition: op.h:4865
PadMode
Definition: op.h:4379
Symbol smooth_l1(const std::string &symbol_name, Symbol data, mx_float scalar)
Definition: op.h:3170
Symbol where(const std::string &symbol_name, Symbol condition, Symbol x, Symbol y)
Definition: op.h:3134
Symbol Dropout(const std::string &symbol_name, Symbol data, mx_float p=0.5, DropoutMode mode=DropoutMode::kTraining)
Definition: op.h:6430
Symbol expm1(const std::string &symbol_name, Symbol data)
Definition: op.h:1892
Symbol elemwise_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2506
PoolingPoolType
Definition: op.h:6288
Symbol relu(const std::string &symbol_name, Symbol data)
Definition: op.h:1198
Symbol reverse(const std::string &symbol_name, Symbol data, Shape axis)
Definition: op.h:633
Symbol rsqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1706
Symbol Pooling(const std::string &symbol_name, Symbol data, Shape kernel, PoolingPoolType pool_type, bool global_pool=false, bool cudnn_off=false, PoolingPoolingConvention pooling_convention=PoolingPoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:6354
Symbol Pooling_v1(const std::string &symbol_name, Symbol data, Shape kernel, Pooling_v1PoolType pool_type, bool global_pool=false, Pooling_v1PoolingConvention pooling_convention=Pooling_v1PoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:6007
SpatialTransformerTransformType
Definition: op.h:5056
ActivationActType
Definition: op.h:6447
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1680
Symbol Softmax(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxNormalization normalization=SoftmaxNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:5362
Symbol rint(const std::string &symbol_name, Symbol data)
Definition: op.h:1509
Symbol IdentityAttachKLSparseReg(const std::string &symbol_name, Symbol data, mx_float sparseness_target=0.1, mx_float penalty=0.001, mx_float momentum=0.9)
Definition: op.h:4508
Symbol sinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3470
Symbol scatter_nd(const std::string &symbol_name, Symbol data, Symbol indices, Shape shape)
Definition: op.h:2921
Symbol broadcast_greater_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3047
Symbol LRN(const std::string &symbol_name, Symbol data, uint32_t nsize, mx_float alpha=0.0001, mx_float beta=0.75, mx_float knorm=2)
Definition: op.h:5753
Symbol sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:3991
Symbol max(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2182
Symbol arcsinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3540
Symbol MAERegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6616
Symbol sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:3932
Symbol SliceChannel(const std::string &symbol_name, Symbol data, int num_outputs, int axis=1, bool squeeze_axis=false)
Definition: op.h:4592
PoolingPoolingConvention
Definition: op.h:6296
Symbol broadcast_minimum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:110
Symbol broadcast_maximum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:77
Symbol Cast(const std::string &symbol_name, Symbol data, CastDtype dtype)
Definition: op.h:1348
DeconvolutionLayout
Definition: op.h:5108
Symbol trunc(const std::string &symbol_name, Symbol data)
Definition: op.h:1594
Pooling_v1PoolType
Definition: op.h:5943
Symbol round(const std::string &symbol_name, Symbol data)
Definition: op.h:1479
Symbol log_softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:3673
Symbol cos(const std::string &symbol_name, Symbol data)
Definition: op.h:3288
Symbol SequenceMask(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, mx_float value=0)
Definition: op.h:5895
Symbol L2Normalization(const std::string &symbol_name, Symbol data, mx_float eps=1e-10, L2NormalizationMode mode=L2NormalizationMode::kInstance)
Definition: op.h:5710
Symbol Correlation(const std::string &symbol_name, Symbol data1, Symbol data2, uint32_t kernel_size=1, uint32_t max_displacement=1, uint32_t stride1=1, uint32_t stride2=1, uint32_t pad_size=0, bool is_multiply=true)
Definition: op.h:5589
Symbol zeros_like(const std::string &symbol_name, Symbol data)
Definition: op.h:699
EmbeddingDtype
Definition: op.h:2588
Symbol batch_dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1167
Symbol broadcast_mod(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:886
Symbol cbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1730
Symbol prod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2067
operator helper functions
Symbol mean(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2030
Symbol tanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3517
Symbol broadcast_to(const std::string &symbol_name, Symbol data, Shape shape=Shape())
Definition: op.h:2300
Symbol elemwise_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2530
DropoutMode
Definition: op.h:6386
Symbol MakeLoss(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float valid_thresh=0, MakeLossNormalization normalization=MakeLossNormalization::kNull)
Definition: op.h:6761
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:1801
Symbol sigmoid(const std::string &symbol_name, Symbol data)
Definition: op.h:1220
Symbol SequenceReverse(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false)
Definition: op.h:5465
CastDtype
Definition: op.h:1321
ConvolutionLayout
Definition: op.h:6047
Symbol LogisticRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6652
Symbol gamma(const std::string &symbol_name, Symbol data)
Definition: op.h:1910
Symbol sin(const std::string &symbol_name, Symbol data)
Definition: op.h:3264
UpSamplingMultiInputMode
Definition: op.h:4661
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
Symbol elemwise_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2557
SpatialTransformerSamplerType
Definition: op.h:5062
Symbol Pad(const std::string &symbol_name, Symbol data, PadMode mode, Shape pad_width, double constant_value=0)
Definition: op.h:4481
Symbol square(const std::string &symbol_name, Symbol data)
Definition: op.h:1651
One_hotDtype
Definition: op.h:2775
UpSamplingSampleType
Definition: op.h:4653
Symbol norm(const std::string &symbol_name, Symbol data)
Definition: op.h:2326
Symbol rmsprop_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, mx_float lr, mx_float gamma1=0.95, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:4211
Symbol LeakyReLU(const std::string &symbol_name, Symbol data, LeakyReLUActType act_type=LeakyReLUActType::kLeaky, mx_float slope=0.25, mx_float lower_bound=0.125, mx_float upper_bound=0.334)
Definition: op.h:3717
Symbol make_loss(const std::string &symbol_name, Symbol data)
Definition: op.h:1296
Symbol SoftmaxActivation(const std::string &symbol_name, Symbol data, SoftmaxActivationMode mode=SoftmaxActivationMode::kInstance)
Definition: op.h:6704
Symbol broadcast_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2954
Symbol Deconvolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), Shape adj=Shape(), Shape target_shape=Shape(), uint32_t num_group=1, uint64_t workspace=512, bool no_bias=true, DeconvolutionCudnnTune cudnn_tune=DeconvolutionCudnnTune::kNone, bool cudnn_off=false, DeconvolutionLayout layout=DeconvolutionLayout::kNone)
Definition: op.h:5143
Symbol broadcast_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:757
Symbol adam_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mean, Symbol var, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4131
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:58
Symbol tan(const std::string &symbol_name, Symbol data)
Definition: op.h:3315
Convolution_v1CudnnTune
Definition: op.h:4927
Symbol repeat(const std::string &symbol_name, Symbol data, int repeats, dmlc::optional< int > axis=dmlc::optional< int >())
Definition: op.h:545
Symbol slice_axis(const std::string &symbol_name, Symbol data, int axis, int begin, dmlc::optional< int > end)
Definition: op.h:453
Symbol expand_dims(const std::string &symbol_name, Symbol data, int axis)
Definition: op.h:344
Symbol arctanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3583
Symbol softmax_cross_entropy(const std::string &symbol_name, Symbol data, Symbol label)
Definition: op.h:4642
Symbol pick(const std::string &symbol_name, Symbol data, Symbol index, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1080
Symbol broadcast_axis(const std::string &symbol_name, Symbol data, Shape axis=Shape(), Shape size=Shape())
Definition: op.h:2261
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1427
Symbol cosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3492
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2439
Symbol gather_nd(const std::string &symbol_name, Symbol data, Symbol indices)
Definition: op.h:2876
Symbol BilinearSampler(const std::string &symbol_name, Symbol data, Symbol grid)
Definition: op.h:6277
Symbol Custom(const std::string &symbol_name, const std::vector< Symbol > &data, const std::string &op_type)
Definition: op.h:3605
Symbol broadcast_hypot(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:149
Symbol BatchNorm_v1(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false)
Definition: op.h:3832
Symbol UpSampling(const std::string &symbol_name, const std::vector< Symbol > &data, uint32_t scale, UpSamplingSampleType sample_type, int num_args, uint32_t num_filter=0, UpSamplingMultiInputMode multi_input_mode=UpSamplingMultiInputMode::kConcat, uint64_t workspace=512)
Definition: op.h:4681
Symbol Activation(const std::string &symbol_name, Symbol data, ActivationActType act_type)
Definition: op.h:6472
float mx_float
manually define float
Definition: c_api.h:60
Symbol SVMOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float margin=1, mx_float regularization_coefficient=1, bool use_linear=false)
Definition: op.h:5625
Symbol radians(const std::string &symbol_name, Symbol data)
Definition: op.h:3445
Symbol Concat(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int dim=1)
Definition: op.h:3894
L2NormalizationMode
Definition: op.h:5642
Symbol stack(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int axis=0)
Definition: op.h:665
Symbol floor(const std::string &symbol_name, Symbol data)
Definition: op.h:1565
Symbol broadcast_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:793
Symbol take(const std::string &symbol_name, Symbol a, Symbol indices, int axis=0, TakeMode mode=TakeMode::kClip)
Definition: op.h:2718
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:1537
Symbol gammaln(const std::string &symbol_name, Symbol data)
Definition: op.h:1928
Symbol tile(const std::string &symbol_name, Symbol data, Shape reps)
Definition: op.h:601
Symbol rmspropalex_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, Symbol g, Symbol delta, mx_float lr, mx_float gamma1=0.95, mx_float gamma2=0.9, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:4281
Symbol argsort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2480
SoftmaxNormalization
Definition: op.h:5329
Symbol softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:3642
DeconvolutionCudnnTune
Definition: op.h:5099
ConvolutionCudnnTune
Definition: op.h:6037
definition of shape
Symbol broadcast_greater(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3016
Symbol BatchNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, Symbol moving_mean, Symbol moving_var, double eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false, int axis=1, bool cudnn_off=false)
Definition: op.h:4771
Symbol rcbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1754
Symbol topk(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), int k=1, TopkRetTyp ret_typ=TopkRetTyp::kIndices, bool is_ascend=false)
Definition: op.h:2386
Symbol broadcast_power(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:44
SoftmaxOutputNormalization
Definition: op.h:5196
Symbol SoftmaxOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxOutputNormalization normalization=SoftmaxOutputNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:5297
Symbol Flatten(const std::string &symbol_name, Symbol data)
Definition: op.h:277
Symbol BlockGrad(const std::string &symbol_name, Symbol data)
Definition: op.h:1260
LeakyReLUActType
Definition: op.h:3684
Symbol arccos(const std::string &symbol_name, Symbol data)
Definition: op.h:3368
Symbol argmax_channel(const std::string &symbol_name, Symbol data)
Definition: op.h:1028
Symbol batch_take(const std::string &symbol_name, Symbol a, Symbol indices)
Definition: op.h:2764
Symbol LinearRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6580
Symbol choose_element_0index(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:6787
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Definition: op.h:232
Symbol degrees(const std::string &symbol_name, Symbol data)
Definition: op.h:3420
Symbol one_hot(const std::string &symbol_name, Symbol indices, int depth, double on_value=1, double off_value=0, One_hotDtype dtype=One_hotDtype::kFloat32)
Definition: op.h:2827
Symbol SequenceLast(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false)
Definition: op.h:5530
Symbol negative(const std::string &symbol_name, Symbol data)
Definition: op.h:1378
Symbol GridGenerator(const std::string &symbol_name, Symbol data, GridGeneratorTransformType transform_type, Shape target_shape=Shape(0, 0))
Definition: op.h:5926
Symbol argmax(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:953
Operator interface.
Definition: operator.h:43
Symbol interface.
Definition: symbol.h:72
Symbol sum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:1993
MakeLossNormalization
Definition: op.h:6721
Symbol log10(const std::string &symbol_name, Symbol data)
Definition: op.h:1822
Symbol dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1132
Symbol log2(const std::string &symbol_name, Symbol data)
Definition: op.h:1843