mxnet
op.h
Go to the documentation of this file.
1 
8 #ifndef MXNET_CPP_OP_H_
9 #define MXNET_CPP_OP_H_
10 
11 #include <string>
12 #include <vector>
13 #include "mxnet-cpp/base.h"
14 #include "mxnet-cpp/shape.h"
15 #include "mxnet-cpp/op_util.h"
16 #include "mxnet-cpp/operator.h"
17 #include "dmlc/optional.h"
18 
19 namespace mxnet {
20 namespace cpp {
21 
62 inline Symbol khatri_rao(const std::string& symbol_name,
63  const std::vector<Symbol>& args) {
64  return Operator("khatri_rao")
65 (args)
66  .CreateSymbol(symbol_name);
67 }
68 
84 inline Symbol Custom(const std::string& symbol_name,
85  const std::vector<Symbol>& data,
86  const std::string& op_type) {
87  return Operator("Custom")
88 (data)
89  .CreateSymbol(symbol_name);
90 }
91 
114 inline Symbol broadcast_power(const std::string& symbol_name,
115  Symbol lhs,
116  Symbol rhs) {
117  return Operator("broadcast_power")
118  .SetInput("lhs", lhs)
119  .SetInput("rhs", rhs)
120  .CreateSymbol(symbol_name);
121 }
122 
147 inline Symbol broadcast_maximum(const std::string& symbol_name,
148  Symbol lhs,
149  Symbol rhs) {
150  return Operator("broadcast_maximum")
151  .SetInput("lhs", lhs)
152  .SetInput("rhs", rhs)
153  .CreateSymbol(symbol_name);
154 }
155 
180 inline Symbol broadcast_minimum(const std::string& symbol_name,
181  Symbol lhs,
182  Symbol rhs) {
183  return Operator("broadcast_minimum")
184  .SetInput("lhs", lhs)
185  .SetInput("rhs", rhs)
186  .CreateSymbol(symbol_name);
187 }
188 
219 inline Symbol broadcast_hypot(const std::string& symbol_name,
220  Symbol lhs,
221  Symbol rhs) {
222  return Operator("broadcast_hypot")
223  .SetInput("lhs", lhs)
224  .SetInput("rhs", rhs)
225  .CreateSymbol(symbol_name);
226 }
227 
302 inline Symbol Reshape(const std::string& symbol_name,
303  Symbol data,
304  Shape shape = Shape(),
305  bool reverse = false,
306  Shape target_shape = Shape(),
307  bool keep_highest = false) {
308  return Operator("Reshape")
309  .SetParam("shape", shape)
310  .SetParam("reverse", reverse)
311  .SetParam("target_shape", target_shape)
312  .SetParam("keep_highest", keep_highest)
313  .SetInput("data", data)
314  .CreateSymbol(symbol_name);
315 }
316 
350 inline Symbol Flatten(const std::string& symbol_name,
351  Symbol data) {
352  return Operator("Flatten")
353  .SetInput("data", data)
354  .CreateSymbol(symbol_name);
355 }
356 
393 inline Symbol transpose(const std::string& symbol_name,
394  Symbol data,
395  Shape axes = Shape()) {
396  return Operator("transpose")
397  .SetParam("axes", axes)
398  .SetInput("data", data)
399  .CreateSymbol(symbol_name);
400 }
401 
417 inline Symbol expand_dims(const std::string& symbol_name,
418  Symbol data,
419  int axis) {
420  return Operator("expand_dims")
421  .SetParam("axis", axis)
422  .SetInput("data", data)
423  .CreateSymbol(symbol_name);
424 }
425 
481 inline Symbol slice(const std::string& symbol_name,
482  Symbol data,
483  Shape begin,
484  Shape end,
485  Shape step = Shape()) {
486  return Operator("slice")
487  .SetParam("begin", begin)
488  .SetParam("end", end)
489  .SetParam("step", step)
490  .SetInput("data", data)
491  .CreateSymbol(symbol_name);
492 }
493 
526 inline Symbol slice_axis(const std::string& symbol_name,
527  Symbol data,
528  int axis,
529  int begin,
530  dmlc::optional<int> end) {
531  return Operator("slice_axis")
532  .SetParam("axis", axis)
533  .SetParam("begin", begin)
534  .SetParam("end", end)
535  .SetInput("data", data)
536  .CreateSymbol(symbol_name);
537 }
538 
600 inline Symbol slice_like(const std::string& symbol_name,
601  Symbol data,
602  Symbol shape_like,
603  Shape axes = Shape()) {
604  return Operator("slice_like")
605  .SetParam("axes", axes)
606  .SetInput("data", data)
607  .SetInput("shape_like", shape_like)
608  .CreateSymbol(symbol_name);
609 }
610 
645 inline Symbol clip(const std::string& symbol_name,
646  Symbol data,
647  mx_float a_min,
648  mx_float a_max) {
649  return Operator("clip")
650  .SetParam("a_min", a_min)
651  .SetParam("a_max", a_max)
652  .SetInput("data", data)
653  .CreateSymbol(symbol_name);
654 }
655 
690 inline Symbol repeat(const std::string& symbol_name,
691  Symbol data,
692  int repeats,
693  dmlc::optional<int> axis = dmlc::optional<int>()) {
694  return Operator("repeat")
695  .SetParam("repeats", repeats)
696  .SetParam("axis", axis)
697  .SetInput("data", data)
698  .CreateSymbol(symbol_name);
699 }
700 
746 inline Symbol tile(const std::string& symbol_name,
747  Symbol data,
748  Shape reps) {
749  return Operator("tile")
750  .SetParam("reps", reps)
751  .SetInput("data", data)
752  .CreateSymbol(symbol_name);
753 }
754 
778 inline Symbol reverse(const std::string& symbol_name,
779  Symbol data,
780  Shape axis) {
781  return Operator("reverse")
782  .SetParam("axis", axis)
783  .SetInput("data", data)
784  .CreateSymbol(symbol_name);
785 }
786 
810 inline Symbol stack(const std::string& symbol_name,
811  const std::vector<Symbol>& data,
812  int num_args,
813  int axis = 0) {
814  return Operator("stack")
815  .SetParam("num_args", num_args)
816  .SetParam("axis", axis)
817 (data)
818  .CreateSymbol(symbol_name);
819 }
820 
843 inline Symbol squeeze(const std::string& symbol_name,
844  const std::vector<Symbol>& data,
845  dmlc::optional<Shape> axis = dmlc::optional<Shape>()) {
846  return Operator("squeeze")
847  .SetParam("axis", axis)
848 (data)
849  .CreateSymbol(symbol_name);
850 }
851 
875 inline Symbol zeros_like(const std::string& symbol_name,
876  Symbol data) {
877  return Operator("zeros_like")
878  .SetInput("data", data)
879  .CreateSymbol(symbol_name);
880 }
881 
899 inline Symbol ones_like(const std::string& symbol_name,
900  Symbol data) {
901  return Operator("ones_like")
902  .SetInput("data", data)
903  .CreateSymbol(symbol_name);
904 }
905 
933 inline Symbol broadcast_add(const std::string& symbol_name,
934  Symbol lhs,
935  Symbol rhs) {
936  return Operator("broadcast_add")
937  .SetInput("lhs", lhs)
938  .SetInput("rhs", rhs)
939  .CreateSymbol(symbol_name);
940 }
941 
969 inline Symbol broadcast_sub(const std::string& symbol_name,
970  Symbol lhs,
971  Symbol rhs) {
972  return Operator("broadcast_sub")
973  .SetInput("lhs", lhs)
974  .SetInput("rhs", rhs)
975  .CreateSymbol(symbol_name);
976 }
977 
1003 inline Symbol broadcast_mul(const std::string& symbol_name,
1004  Symbol lhs,
1005  Symbol rhs) {
1006  return Operator("broadcast_mul")
1007  .SetInput("lhs", lhs)
1008  .SetInput("rhs", rhs)
1009  .CreateSymbol(symbol_name);
1010 }
1011 
1037 inline Symbol broadcast_div(const std::string& symbol_name,
1038  Symbol lhs,
1039  Symbol rhs) {
1040  return Operator("broadcast_div")
1041  .SetInput("lhs", lhs)
1042  .SetInput("rhs", rhs)
1043  .CreateSymbol(symbol_name);
1044 }
1045 
1068 inline Symbol broadcast_mod(const std::string& symbol_name,
1069  Symbol lhs,
1070  Symbol rhs) {
1071  return Operator("broadcast_mod")
1072  .SetInput("lhs", lhs)
1073  .SetInput("rhs", rhs)
1074  .CreateSymbol(symbol_name);
1075 }
1076 
1097 inline Symbol add_n(const std::string& symbol_name,
1098  const std::vector<Symbol>& args) {
1099  return Operator("add_n")
1100 (args)
1101  .CreateSymbol(symbol_name);
1102 }
1103 
1135 inline Symbol argmax(const std::string& symbol_name,
1136  Symbol data,
1137  dmlc::optional<int> axis = dmlc::optional<int>(),
1138  bool keepdims = false) {
1139  return Operator("argmax")
1140  .SetParam("axis", axis)
1141  .SetParam("keepdims", keepdims)
1142  .SetInput("data", data)
1143  .CreateSymbol(symbol_name);
1144 }
1145 
1177 inline Symbol argmin(const std::string& symbol_name,
1178  Symbol data,
1179  dmlc::optional<int> axis = dmlc::optional<int>(),
1180  bool keepdims = false) {
1181  return Operator("argmin")
1182  .SetParam("axis", axis)
1183  .SetParam("keepdims", keepdims)
1184  .SetInput("data", data)
1185  .CreateSymbol(symbol_name);
1186 }
1187 
1210 inline Symbol argmax_channel(const std::string& symbol_name,
1211  Symbol data) {
1212  return Operator("argmax_channel")
1213  .SetInput("data", data)
1214  .CreateSymbol(symbol_name);
1215 }
1216 
1262 inline Symbol pick(const std::string& symbol_name,
1263  Symbol data,
1264  Symbol index,
1265  dmlc::optional<int> axis = dmlc::optional<int>(),
1266  bool keepdims = false) {
1267  return Operator("pick")
1268  .SetParam("axis", axis)
1269  .SetParam("keepdims", keepdims)
1270  .SetInput("data", data)
1271  .SetInput("index", index)
1272  .CreateSymbol(symbol_name);
1273 }
1274 
1315 inline Symbol dot(const std::string& symbol_name,
1316  Symbol lhs,
1317  Symbol rhs,
1318  bool transpose_a = false,
1319  bool transpose_b = false) {
1320  return Operator("dot")
1321  .SetParam("transpose_a", transpose_a)
1322  .SetParam("transpose_b", transpose_b)
1323  .SetInput("lhs", lhs)
1324  .SetInput("rhs", rhs)
1325  .CreateSymbol(symbol_name);
1326 }
1327 
1350 inline Symbol batch_dot(const std::string& symbol_name,
1351  Symbol lhs,
1352  Symbol rhs,
1353  bool transpose_a = false,
1354  bool transpose_b = false) {
1355  return Operator("batch_dot")
1356  .SetParam("transpose_a", transpose_a)
1357  .SetParam("transpose_b", transpose_b)
1358  .SetInput("lhs", lhs)
1359  .SetInput("rhs", rhs)
1360  .CreateSymbol(symbol_name);
1361 }
1362 
1381 inline Symbol relu(const std::string& symbol_name,
1382  Symbol data) {
1383  return Operator("relu")
1384  .SetInput("data", data)
1385  .CreateSymbol(symbol_name);
1386 }
1387 
1403 inline Symbol sigmoid(const std::string& symbol_name,
1404  Symbol data) {
1405  return Operator("sigmoid")
1406  .SetInput("data", data)
1407  .CreateSymbol(symbol_name);
1408 }
1409 
1425 inline Symbol softsign(const std::string& symbol_name,
1426  Symbol data) {
1427  return Operator("softsign")
1428  .SetInput("data", data)
1429  .CreateSymbol(symbol_name);
1430 }
1431 
1465 inline Symbol BlockGrad(const std::string& symbol_name,
1466  Symbol data) {
1467  return Operator("BlockGrad")
1468  .SetInput("data", data)
1469  .CreateSymbol(symbol_name);
1470 }
1471 
1501 inline Symbol make_loss(const std::string& symbol_name,
1502  Symbol data) {
1503  return Operator("make_loss")
1504  .SetInput("data", data)
1505  .CreateSymbol(symbol_name);
1506 }
1507 
1515 inline Symbol reshape_like(const std::string& symbol_name,
1516  Symbol lhs,
1517  Symbol rhs) {
1518  return Operator("reshape_like")
1519  .SetInput("lhs", lhs)
1520  .SetInput("rhs", rhs)
1521  .CreateSymbol(symbol_name);
1522 }
1523 
/*! \brief Target dtype choices for Cast(); each enumerator's value indexes Cast()'s name table. */
enum class CastDtype {
  kFloat16 = 0, kFloat32 = 1, kFloat64 = 2,
  kInt32 = 3, kInt64 = 4, kInt8 = 5, kUint8 = 6
};
1535 
1555 inline Symbol Cast(const std::string& symbol_name,
1556  Symbol data,
1557  CastDtype dtype) {
1558  static const char *CastDtypeValues[] = {
1559  "float16",
1560  "float32",
1561  "float64",
1562  "int32",
1563  "int64",
1564  "int8",
1565  "uint8"
1566  };
1567  return Operator("Cast")
1568  .SetParam("dtype", CastDtypeValues[int(dtype)])
1569  .SetInput("data", data)
1570  .CreateSymbol(symbol_name);
1571 }
1572 
1587 inline Symbol negative(const std::string& symbol_name,
1588  Symbol data) {
1589  return Operator("negative")
1590  .SetInput("data", data)
1591  .CreateSymbol(symbol_name);
1592 }
1593 
1610 inline Symbol reciprocal(const std::string& symbol_name,
1611  Symbol data) {
1612  return Operator("reciprocal")
1613  .SetInput("data", data)
1614  .CreateSymbol(symbol_name);
1615 }
1616 
1636 inline Symbol abs(const std::string& symbol_name,
1637  Symbol data) {
1638  return Operator("abs")
1639  .SetInput("data", data)
1640  .CreateSymbol(symbol_name);
1641 }
1642 
1662 inline Symbol sign(const std::string& symbol_name,
1663  Symbol data) {
1664  return Operator("sign")
1665  .SetInput("data", data)
1666  .CreateSymbol(symbol_name);
1667 }
1668 
1688 inline Symbol round(const std::string& symbol_name,
1689  Symbol data) {
1690  return Operator("round")
1691  .SetInput("data", data)
1692  .CreateSymbol(symbol_name);
1693 }
1694 
1718 inline Symbol rint(const std::string& symbol_name,
1719  Symbol data) {
1720  return Operator("rint")
1721  .SetInput("data", data)
1722  .CreateSymbol(symbol_name);
1723 }
1724 
1746 inline Symbol ceil(const std::string& symbol_name,
1747  Symbol data) {
1748  return Operator("ceil")
1749  .SetInput("data", data)
1750  .CreateSymbol(symbol_name);
1751 }
1752 
1774 inline Symbol floor(const std::string& symbol_name,
1775  Symbol data) {
1776  return Operator("floor")
1777  .SetInput("data", data)
1778  .CreateSymbol(symbol_name);
1779 }
1780 
1803 inline Symbol trunc(const std::string& symbol_name,
1804  Symbol data) {
1805  return Operator("trunc")
1806  .SetInput("data", data)
1807  .CreateSymbol(symbol_name);
1808 }
1809 
1830 inline Symbol fix(const std::string& symbol_name,
1831  Symbol data) {
1832  return Operator("fix")
1833  .SetInput("data", data)
1834  .CreateSymbol(symbol_name);
1835 }
1836 
1860 inline Symbol square(const std::string& symbol_name,
1861  Symbol data) {
1862  return Operator("square")
1863  .SetInput("data", data)
1864  .CreateSymbol(symbol_name);
1865 }
1866 
1889 inline Symbol sqrt(const std::string& symbol_name,
1890  Symbol data) {
1891  return Operator("sqrt")
1892  .SetInput("data", data)
1893  .CreateSymbol(symbol_name);
1894 }
1895 
1915 inline Symbol rsqrt(const std::string& symbol_name,
1916  Symbol data) {
1917  return Operator("rsqrt")
1918  .SetInput("data", data)
1919  .CreateSymbol(symbol_name);
1920 }
1921 
1939 inline Symbol cbrt(const std::string& symbol_name,
1940  Symbol data) {
1941  return Operator("cbrt")
1942  .SetInput("data", data)
1943  .CreateSymbol(symbol_name);
1944 }
1945 
1963 inline Symbol rcbrt(const std::string& symbol_name,
1964  Symbol data) {
1965  return Operator("rcbrt")
1966  .SetInput("data", data)
1967  .CreateSymbol(symbol_name);
1968 }
1969 
1989 inline Symbol exp(const std::string& symbol_name,
1990  Symbol data) {
1991  return Operator("exp")
1992  .SetInput("data", data)
1993  .CreateSymbol(symbol_name);
1994 }
1995 
2010 inline Symbol log(const std::string& symbol_name,
2011  Symbol data) {
2012  return Operator("log")
2013  .SetInput("data", data)
2014  .CreateSymbol(symbol_name);
2015 }
2016 
2031 inline Symbol log10(const std::string& symbol_name,
2032  Symbol data) {
2033  return Operator("log10")
2034  .SetInput("data", data)
2035  .CreateSymbol(symbol_name);
2036 }
2037 
2052 inline Symbol log2(const std::string& symbol_name,
2053  Symbol data) {
2054  return Operator("log2")
2055  .SetInput("data", data)
2056  .CreateSymbol(symbol_name);
2057 }
2058 
2077 inline Symbol log1p(const std::string& symbol_name,
2078  Symbol data) {
2079  return Operator("log1p")
2080  .SetInput("data", data)
2081  .CreateSymbol(symbol_name);
2082 }
2083 
2101 inline Symbol expm1(const std::string& symbol_name,
2102  Symbol data) {
2103  return Operator("expm1")
2104  .SetInput("data", data)
2105  .CreateSymbol(symbol_name);
2106 }
2107 
2119 inline Symbol gamma(const std::string& symbol_name,
2120  Symbol data) {
2121  return Operator("gamma")
2122  .SetInput("data", data)
2123  .CreateSymbol(symbol_name);
2124 }
2125 
2137 inline Symbol gammaln(const std::string& symbol_name,
2138  Symbol data) {
2139  return Operator("gammaln")
2140  .SetInput("data", data)
2141  .CreateSymbol(symbol_name);
2142 }
2143 
2202 inline Symbol sum(const std::string& symbol_name,
2203  Symbol data,
2204  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2205  bool keepdims = false,
2206  bool exclude = false) {
2207  return Operator("sum")
2208  .SetParam("axis", axis)
2209  .SetParam("keepdims", keepdims)
2210  .SetParam("exclude", exclude)
2211  .SetInput("data", data)
2212  .CreateSymbol(symbol_name);
2213 }
2214 
2239 inline Symbol mean(const std::string& symbol_name,
2240  Symbol data,
2241  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2242  bool keepdims = false,
2243  bool exclude = false) {
2244  return Operator("mean")
2245  .SetParam("axis", axis)
2246  .SetParam("keepdims", keepdims)
2247  .SetParam("exclude", exclude)
2248  .SetInput("data", data)
2249  .CreateSymbol(symbol_name);
2250 }
2251 
2276 inline Symbol prod(const std::string& symbol_name,
2277  Symbol data,
2278  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2279  bool keepdims = false,
2280  bool exclude = false) {
2281  return Operator("prod")
2282  .SetParam("axis", axis)
2283  .SetParam("keepdims", keepdims)
2284  .SetParam("exclude", exclude)
2285  .SetInput("data", data)
2286  .CreateSymbol(symbol_name);
2287 }
2288 
2315 inline Symbol nansum(const std::string& symbol_name,
2316  Symbol data,
2317  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2318  bool keepdims = false,
2319  bool exclude = false) {
2320  return Operator("nansum")
2321  .SetParam("axis", axis)
2322  .SetParam("keepdims", keepdims)
2323  .SetParam("exclude", exclude)
2324  .SetInput("data", data)
2325  .CreateSymbol(symbol_name);
2326 }
2327 
2354 inline Symbol nanprod(const std::string& symbol_name,
2355  Symbol data,
2356  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2357  bool keepdims = false,
2358  bool exclude = false) {
2359  return Operator("nanprod")
2360  .SetParam("axis", axis)
2361  .SetParam("keepdims", keepdims)
2362  .SetParam("exclude", exclude)
2363  .SetInput("data", data)
2364  .CreateSymbol(symbol_name);
2365 }
2366 
2391 inline Symbol max(const std::string& symbol_name,
2392  Symbol data,
2393  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2394  bool keepdims = false,
2395  bool exclude = false) {
2396  return Operator("max")
2397  .SetParam("axis", axis)
2398  .SetParam("keepdims", keepdims)
2399  .SetParam("exclude", exclude)
2400  .SetInput("data", data)
2401  .CreateSymbol(symbol_name);
2402 }
2403 
2428 inline Symbol min(const std::string& symbol_name,
2429  Symbol data,
2430  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2431  bool keepdims = false,
2432  bool exclude = false) {
2433  return Operator("min")
2434  .SetParam("axis", axis)
2435  .SetParam("keepdims", keepdims)
2436  .SetParam("exclude", exclude)
2437  .SetInput("data", data)
2438  .CreateSymbol(symbol_name);
2439 }
2440 
2470 inline Symbol broadcast_axis(const std::string& symbol_name,
2471  Symbol data,
2472  Shape axis = Shape(),
2473  Shape size = Shape()) {
2474  return Operator("broadcast_axis")
2475  .SetParam("axis", axis)
2476  .SetParam("size", size)
2477  .SetInput("data", data)
2478  .CreateSymbol(symbol_name);
2479 }
2480 
2509 inline Symbol broadcast_to(const std::string& symbol_name,
2510  Symbol data,
2511  Shape shape = Shape()) {
2512  return Operator("broadcast_to")
2513  .SetParam("shape", shape)
2514  .SetInput("data", data)
2515  .CreateSymbol(symbol_name);
2516 }
2517 
2555 inline Symbol norm(const std::string& symbol_name,
2556  Symbol data,
2557  int ord = 2,
2558  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2559  bool keepdims = false) {
2560  return Operator("norm")
2561  .SetParam("ord", ord)
2562  .SetParam("axis", axis)
2563  .SetParam("keepdims", keepdims)
2564  .SetInput("data", data)
2565  .CreateSymbol(symbol_name);
2566 }
2567 
/*! \brief "ret_typ" choices for topk(); each enumerator's value indexes topk()'s name table. */
enum class TopkRetTyp {
  kBoth = 0, kIndices = 1, kMask = 2, kValue = 3
};
2579 
2621 inline Symbol topk(const std::string& symbol_name,
2622  Symbol data,
2623  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2624  int k = 1,
2625  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
2626  bool is_ascend = false) {
2627  static const char *TopkRetTypValues[] = {
2628  "both",
2629  "indices",
2630  "mask",
2631  "value"
2632  };
2633  return Operator("topk")
2634  .SetParam("axis", axis)
2635  .SetParam("k", k)
2636  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
2637  .SetParam("is_ascend", is_ascend)
2638  .SetInput("data", data)
2639  .CreateSymbol(symbol_name);
2640 }
2641 
2674 inline Symbol sort(const std::string& symbol_name,
2675  Symbol data,
2676  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2677  bool is_ascend = true) {
2678  return Operator("sort")
2679  .SetParam("axis", axis)
2680  .SetParam("is_ascend", is_ascend)
2681  .SetInput("data", data)
2682  .CreateSymbol(symbol_name);
2683 }
2684 
2715 inline Symbol argsort(const std::string& symbol_name,
2716  Symbol data,
2717  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2718  bool is_ascend = true) {
2719  return Operator("argsort")
2720  .SetParam("axis", axis)
2721  .SetParam("is_ascend", is_ascend)
2722  .SetInput("data", data)
2723  .CreateSymbol(symbol_name);
2724 }
2725 
2741 inline Symbol elemwise_add(const std::string& symbol_name,
2742  Symbol lhs,
2743  Symbol rhs) {
2744  return Operator("elemwise_add")
2745  .SetInput("lhs", lhs)
2746  .SetInput("rhs", rhs)
2747  .CreateSymbol(symbol_name);
2748 }
2749 
2765 inline Symbol elemwise_sub(const std::string& symbol_name,
2766  Symbol lhs,
2767  Symbol rhs) {
2768  return Operator("elemwise_sub")
2769  .SetInput("lhs", lhs)
2770  .SetInput("rhs", rhs)
2771  .CreateSymbol(symbol_name);
2772 }
2773 
2792 inline Symbol elemwise_mul(const std::string& symbol_name,
2793  Symbol lhs,
2794  Symbol rhs) {
2795  return Operator("elemwise_mul")
2796  .SetInput("lhs", lhs)
2797  .SetInput("rhs", rhs)
2798  .CreateSymbol(symbol_name);
2799 }
2800 
2812 inline Symbol elemwise_div(const std::string& symbol_name,
2813  Symbol lhs,
2814  Symbol rhs) {
2815  return Operator("elemwise_div")
2816  .SetInput("lhs", lhs)
2817  .SetInput("rhs", rhs)
2818  .CreateSymbol(symbol_name);
2819 }
2820 
/*! \brief "dtype" choices for Embedding(); each enumerator's value indexes Embedding()'s name table. */
enum class EmbeddingDtype {
  kFloat16 = 0, kFloat32 = 1, kFloat64 = 2,
  kInt32 = 3, kInt64 = 4, kInt8 = 5, kUint8 = 6
};
2832 
2884 inline Symbol Embedding(const std::string& symbol_name,
2885  Symbol data,
2886  Symbol weight,
2887  int input_dim,
2888  int output_dim,
2890  static const char *EmbeddingDtypeValues[] = {
2891  "float16",
2892  "float32",
2893  "float64",
2894  "int32",
2895  "int64",
2896  "int8",
2897  "uint8"
2898  };
2899  return Operator("Embedding")
2900  .SetParam("input_dim", input_dim)
2901  .SetParam("output_dim", output_dim)
2902  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
2903  .SetInput("data", data)
2904  .SetInput("weight", weight)
2905  .CreateSymbol(symbol_name);
2906 }
2907 
/*! \brief "mode" choices for take(); each enumerator's value indexes take()'s name table. */
enum class TakeMode {
  kClip = 0, kRaise = 1, kWrap = 2
};
2917 
2961 inline Symbol take(const std::string& symbol_name,
2962  Symbol a,
2963  Symbol indices,
2964  int axis = 0,
2965  TakeMode mode = TakeMode::kClip) {
2966  static const char *TakeModeValues[] = {
2967  "clip",
2968  "raise",
2969  "wrap"
2970  };
2971  return Operator("take")
2972  .SetParam("axis", axis)
2973  .SetParam("mode", TakeModeValues[int(mode)])
2974  .SetInput("a", a)
2975  .SetInput("indices", indices)
2976  .CreateSymbol(symbol_name);
2977 }
2978 
3007 inline Symbol batch_take(const std::string& symbol_name,
3008  Symbol a,
3009  Symbol indices) {
3010  return Operator("batch_take")
3011  .SetInput("a", a)
3012  .SetInput("indices", indices)
3013  .CreateSymbol(symbol_name);
3014 }
3015 
/*! \brief "dtype" choices for one_hot(); each enumerator's value indexes one_hot()'s name table. */
enum class One_hotDtype {
  kFloat16 = 0, kFloat32 = 1, kFloat64 = 2,
  kInt32 = 3, kInt64 = 4, kInt8 = 5, kUint8 = 6
};
3027 
3072 inline Symbol one_hot(const std::string& symbol_name,
3073  Symbol indices,
3074  int depth,
3075  double on_value = 1,
3076  double off_value = 0,
3078  static const char *One_hotDtypeValues[] = {
3079  "float16",
3080  "float32",
3081  "float64",
3082  "int32",
3083  "int64",
3084  "int8",
3085  "uint8"
3086  };
3087  return Operator("one_hot")
3088  .SetParam("depth", depth)
3089  .SetParam("on_value", on_value)
3090  .SetParam("off_value", off_value)
3091  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
3092  .SetInput("indices", indices)
3093  .CreateSymbol(symbol_name);
3094 }
3095 
3123 inline Symbol gather_nd(const std::string& symbol_name,
3124  Symbol data,
3125  Symbol indices) {
3126  return Operator("gather_nd")
3127  .SetInput("data", data)
3128  .SetInput("indices", indices)
3129  .CreateSymbol(symbol_name);
3130 }
3131 
3168 inline Symbol scatter_nd(const std::string& symbol_name,
3169  Symbol data,
3170  Symbol indices,
3171  Shape shape) {
3172  return Operator("scatter_nd")
3173  .SetParam("shape", shape)
3174  .SetInput("data", data)
3175  .SetInput("indices", indices)
3176  .CreateSymbol(symbol_name);
3177 }
3178 
3201 inline Symbol broadcast_equal(const std::string& symbol_name,
3202  Symbol lhs,
3203  Symbol rhs) {
3204  return Operator("broadcast_equal")
3205  .SetInput("lhs", lhs)
3206  .SetInput("rhs", rhs)
3207  .CreateSymbol(symbol_name);
3208 }
3209 
3232 inline Symbol broadcast_not_equal(const std::string& symbol_name,
3233  Symbol lhs,
3234  Symbol rhs) {
3235  return Operator("broadcast_not_equal")
3236  .SetInput("lhs", lhs)
3237  .SetInput("rhs", rhs)
3238  .CreateSymbol(symbol_name);
3239 }
3240 
3263 inline Symbol broadcast_greater(const std::string& symbol_name,
3264  Symbol lhs,
3265  Symbol rhs) {
3266  return Operator("broadcast_greater")
3267  .SetInput("lhs", lhs)
3268  .SetInput("rhs", rhs)
3269  .CreateSymbol(symbol_name);
3270 }
3271 
3294 inline Symbol broadcast_greater_equal(const std::string& symbol_name,
3295  Symbol lhs,
3296  Symbol rhs) {
3297  return Operator("broadcast_greater_equal")
3298  .SetInput("lhs", lhs)
3299  .SetInput("rhs", rhs)
3300  .CreateSymbol(symbol_name);
3301 }
3302 
3325 inline Symbol broadcast_lesser(const std::string& symbol_name,
3326  Symbol lhs,
3327  Symbol rhs) {
3328  return Operator("broadcast_lesser")
3329  .SetInput("lhs", lhs)
3330  .SetInput("rhs", rhs)
3331  .CreateSymbol(symbol_name);
3332 }
3333 
3356 inline Symbol broadcast_lesser_equal(const std::string& symbol_name,
3357  Symbol lhs,
3358  Symbol rhs) {
3359  return Operator("broadcast_lesser_equal")
3360  .SetInput("lhs", lhs)
3361  .SetInput("rhs", rhs)
3362  .CreateSymbol(symbol_name);
3363 }
3364 
3400 inline Symbol where(const std::string& symbol_name,
3401  Symbol condition,
3402  Symbol x,
3403  Symbol y) {
3404  return Operator("where")
3405  .SetInput("condition", condition)
3406  .SetInput("x", x)
3407  .SetInput("y", y)
3408  .CreateSymbol(symbol_name);
3409 }
3410 
3436 inline Symbol smooth_l1(const std::string& symbol_name,
3437  Symbol data,
3438  mx_float scalar) {
3439  return Operator("smooth_l1")
3440  .SetParam("scalar", scalar)
3441  .SetInput("data", data)
3442  .CreateSymbol(symbol_name);
3443 }
3444 
/*! \brief "stype" choices for cast_storage(); each enumerator's value indexes cast_storage()'s name table. */
enum class Cast_storageStype {
  kCsr = 0, kDefault = 1, kRow_sparse = 2
};
3452 
3498 inline Symbol cast_storage(const std::string& symbol_name,
3499  Symbol data,
3500  Cast_storageStype stype) {
3501  static const char *Cast_storageStypeValues[] = {
3502  "csr",
3503  "default",
3504  "row_sparse"
3505  };
3506  return Operator("cast_storage")
3507  .SetParam("stype", Cast_storageStypeValues[int(stype)])
3508  .SetInput("data", data)
3509  .CreateSymbol(symbol_name);
3510 }
3511 
3532 inline Symbol sin(const std::string& symbol_name,
3533  Symbol data) {
3534  return Operator("sin")
3535  .SetInput("data", data)
3536  .CreateSymbol(symbol_name);
3537 }
3538 
3556 inline Symbol cos(const std::string& symbol_name,
3557  Symbol data) {
3558  return Operator("cos")
3559  .SetInput("data", data)
3560  .CreateSymbol(symbol_name);
3561 }
3562 
3583 inline Symbol tan(const std::string& symbol_name,
3584  Symbol data) {
3585  return Operator("tan")
3586  .SetInput("data", data)
3587  .CreateSymbol(symbol_name);
3588 }
3589 
3611 inline Symbol arcsin(const std::string& symbol_name,
3612  Symbol data) {
3613  return Operator("arcsin")
3614  .SetInput("data", data)
3615  .CreateSymbol(symbol_name);
3616 }
3617 
3636 inline Symbol arccos(const std::string& symbol_name,
3637  Symbol data) {
3638  return Operator("arccos")
3639  .SetInput("data", data)
3640  .CreateSymbol(symbol_name);
3641 }
3642 
3663 inline Symbol arctan(const std::string& symbol_name,
3664  Symbol data) {
3665  return Operator("arctan")
3666  .SetInput("data", data)
3667  .CreateSymbol(symbol_name);
3668 }
3669 
3688 inline Symbol degrees(const std::string& symbol_name,
3689  Symbol data) {
3690  return Operator("degrees")
3691  .SetInput("data", data)
3692  .CreateSymbol(symbol_name);
3693 }
3694 
3713 inline Symbol radians(const std::string& symbol_name,
3714  Symbol data) {
3715  return Operator("radians")
3716  .SetInput("data", data)
3717  .CreateSymbol(symbol_name);
3718 }
3719 
3738 inline Symbol sinh(const std::string& symbol_name,
3739  Symbol data) {
3740  return Operator("sinh")
3741  .SetInput("data", data)
3742  .CreateSymbol(symbol_name);
3743 }
3744 
3760 inline Symbol cosh(const std::string& symbol_name,
3761  Symbol data) {
3762  return Operator("cosh")
3763  .SetInput("data", data)
3764  .CreateSymbol(symbol_name);
3765 }
3766 
3785 inline Symbol tanh(const std::string& symbol_name,
3786  Symbol data) {
3787  return Operator("tanh")
3788  .SetInput("data", data)
3789  .CreateSymbol(symbol_name);
3790 }
3791 
3808 inline Symbol arcsinh(const std::string& symbol_name,
3809  Symbol data) {
3810  return Operator("arcsinh")
3811  .SetInput("data", data)
3812  .CreateSymbol(symbol_name);
3813 }
3814 
3828 inline Symbol arccosh(const std::string& symbol_name,
3829  Symbol data) {
3830  return Operator("arccosh")
3831  .SetInput("data", data)
3832  .CreateSymbol(symbol_name);
3833 }
3834 
3851 inline Symbol arctanh(const std::string& symbol_name,
3852  Symbol data) {
3853  return Operator("arctanh")
3854  .SetInput("data", data)
3855  .CreateSymbol(symbol_name);
3856 }
3857 
/*! \brief "pool_type" choices for Pooling(); each enumerator's value indexes Pooling()'s pool-type name table. */
enum class PoolingPoolType {
  kAvg = 0, kMax = 1, kSum = 2
};
3865 
/*!
 * \brief "pooling_convention" choices for Pooling(); each enumerator's value
 *        indexes Pooling()'s convention name table ("full", "valid").
 *
 * The `enum class` header was missing, leaving a dangling enumerator list.
 * The name follows the <Op><Param> pattern that Pooling()'s
 * PoolingPoolingConventionValues table already assumes.
 */
enum class PoolingPoolingConvention {
  kFull = 0,
  kValid = 1
};
3872 
3894 inline Symbol Pooling(const std::string& symbol_name,
3895  Symbol data,
3896  Shape kernel = Shape(),
3898  bool global_pool = false,
3899  bool cudnn_off = false,
3901  Shape stride = Shape(),
3902  Shape pad = Shape()) {
3903  static const char *PoolingPoolTypeValues[] = {
3904  "avg",
3905  "max",
3906  "sum"
3907  };
3908  static const char *PoolingPoolingConventionValues[] = {
3909  "full",
3910  "valid"
3911  };
3912  return Operator("Pooling")
3913  .SetParam("kernel", kernel)
3914  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
3915  .SetParam("global_pool", global_pool)
3916  .SetParam("cudnn_off", cudnn_off)
3917  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
3918  .SetParam("stride", stride)
3919  .SetParam("pad", pad)
3920  .SetInput("data", data)
3921  .CreateSymbol(symbol_name);
3922 }
3923 
3953 inline Symbol softmax(const std::string& symbol_name,
3954  Symbol data,
3955  int axis = -1) {
3956  return Operator("softmax")
3957  .SetParam("axis", axis)
3958  .SetInput("data", data)
3959  .CreateSymbol(symbol_name);
3960 }
3961 
3984 inline Symbol log_softmax(const std::string& symbol_name,
3985  Symbol data,
3986  int axis = -1) {
3987  return Operator("log_softmax")
3988  .SetParam("axis", axis)
3989  .SetInput("data", data)
3990  .CreateSymbol(symbol_name);
3991 }
3992 
/*!
 * \brief "cudnn_tune" choices for Deconvolution(); each enumerator's value
 *        indexes Deconvolution()'s cudnn-tune name table.
 *
 * The `enum class` header was missing, leaving a dangling enumerator list.
 * The name follows the <Op><Param> pattern that Deconvolution()'s
 * DeconvolutionCudnnTuneValues table already assumes.
 */
enum class DeconvolutionCudnnTune {
  kNone = 0,
  kFastest = 1,
  kLimited_workspace = 2,
  kOff = 3
};
4001 
/*!
 * \brief "layout" choices for Deconvolution(); each enumerator's value
 *        indexes Deconvolution()'s layout name table.
 *
 * The `enum class` header was missing, leaving a dangling enumerator list.
 * The name follows the <Op><Param> pattern that Deconvolution()'s
 * DeconvolutionLayoutValues table already assumes.
 */
enum class DeconvolutionLayout {
  kNone = 0,
  kNCDHW = 1,
  kNCHW = 2,
  kNCW = 3,
  kNDHWC = 4,
  kNHWC = 5
};
4012 
4042 inline Symbol Deconvolution(const std::string& symbol_name,
4043  Symbol data,
4044  Symbol weight,
4045  Symbol bias,
4046  Shape kernel,
4047  uint32_t num_filter,
4048  Shape stride = Shape(),
4049  Shape dilate = Shape(),
4050  Shape pad = Shape(),
4051  Shape adj = Shape(),
4052  Shape target_shape = Shape(),
4053  uint32_t num_group = 1,
4054  uint64_t workspace = 512,
4055  bool no_bias = true,
4057  bool cudnn_off = false,
4059  static const char *DeconvolutionCudnnTuneValues[] = {
4060  "None",
4061  "fastest",
4062  "limited_workspace",
4063  "off"
4064  };
4065  static const char *DeconvolutionLayoutValues[] = {
4066  "None",
4067  "NCDHW",
4068  "NCHW",
4069  "NCW",
4070  "NDHWC",
4071  "NHWC"
4072  };
4073  return Operator("Deconvolution")
4074  .SetParam("kernel", kernel)
4075  .SetParam("num_filter", num_filter)
4076  .SetParam("stride", stride)
4077  .SetParam("dilate", dilate)
4078  .SetParam("pad", pad)
4079  .SetParam("adj", adj)
4080  .SetParam("target_shape", target_shape)
4081  .SetParam("num_group", num_group)
4082  .SetParam("workspace", workspace)
4083  .SetParam("no_bias", no_bias)
4084  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
4085  .SetParam("cudnn_off", cudnn_off)
4086  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
4087  .SetInput("data", data)
4088  .SetInput("weight", weight)
4089  .SetInput("bias", bias)
4090  .CreateSymbol(symbol_name);
4091 }
4092 
/*!
 * \brief Activation kinds accepted by Activation().
 *
 * Activation() uses each value as an index into its string table, so the
 * values must start at 0 and stay consecutive.
 */
enum class ActivationActType : int {
  kRelu = 0,      // "relu"
  kSigmoid = 1,   // "sigmoid"
  kSoftrelu = 2,  // "softrelu"
  kSoftsign = 3,  // "softsign"
  kTanh = 4       // "tanh"
};
4102 
4122 inline Symbol Activation(const std::string& symbol_name,
4123  Symbol data,
4124  ActivationActType act_type) {
4125  static const char *ActivationActTypeValues[] = {
4126  "relu",
4127  "sigmoid",
4128  "softrelu",
4129  "softsign",
4130  "tanh"
4131  };
4132  return Operator("Activation")
4133  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
4134  .SetInput("data", data)
4135  .CreateSymbol(symbol_name);
4136 }
4137 
4202 inline Symbol BatchNorm(const std::string& symbol_name,
4203  Symbol data,
4204  Symbol gamma,
4205  Symbol beta,
4206  Symbol moving_mean,
4207  Symbol moving_var,
4208  double eps = 0.001,
4209  mx_float momentum = 0.9,
4210  bool fix_gamma = true,
4211  bool use_global_stats = false,
4212  bool output_mean_var = false,
4213  int axis = 1,
4214  bool cudnn_off = false) {
4215  return Operator("BatchNorm")
4216  .SetParam("eps", eps)
4217  .SetParam("momentum", momentum)
4218  .SetParam("fix_gamma", fix_gamma)
4219  .SetParam("use_global_stats", use_global_stats)
4220  .SetParam("output_mean_var", output_mean_var)
4221  .SetParam("axis", axis)
4222  .SetParam("cudnn_off", cudnn_off)
4223  .SetInput("data", data)
4224  .SetInput("gamma", gamma)
4225  .SetInput("beta", beta)
4226  .SetInput("moving_mean", moving_mean)
4227  .SetInput("moving_var", moving_var)
4228  .CreateSymbol(symbol_name);
4229 }
4230 
4234  kNone = 0,
4235  kFastest = 1,
4236  kLimited_workspace = 2,
4237  kOff = 3
4238 };
4239 
/*!
 * \brief Data layouts accepted by Convolution().
 *
 * Convolution() uses each value as an index into its string table, so the
 * values must start at 0 and stay consecutive.
 */
enum class ConvolutionLayout : int {
  kNone = 0,   // "None"
  kNCDHW = 1,  // "NCDHW"
  kNCHW = 2,   // "NCHW"
  kNCW = 3,    // "NCW"
  kNDHWC = 4,  // "NDHWC"
  kNHWC = 5    // "NHWC"
};
4251 
4349 inline Symbol Convolution(const std::string& symbol_name,
4350  Symbol data,
4351  Symbol weight,
4352  Symbol bias,
4353  Shape kernel,
4354  uint32_t num_filter,
4355  Shape stride = Shape(),
4356  Shape dilate = Shape(),
4357  Shape pad = Shape(),
4358  uint32_t num_group = 1,
4359  uint64_t workspace = 1024,
4360  bool no_bias = false,
4362  bool cudnn_off = false,
4364  static const char *ConvolutionCudnnTuneValues[] = {
4365  "None",
4366  "fastest",
4367  "limited_workspace",
4368  "off"
4369  };
4370  static const char *ConvolutionLayoutValues[] = {
4371  "None",
4372  "NCDHW",
4373  "NCHW",
4374  "NCW",
4375  "NDHWC",
4376  "NHWC"
4377  };
4378  return Operator("Convolution")
4379  .SetParam("kernel", kernel)
4380  .SetParam("num_filter", num_filter)
4381  .SetParam("stride", stride)
4382  .SetParam("dilate", dilate)
4383  .SetParam("pad", pad)
4384  .SetParam("num_group", num_group)
4385  .SetParam("workspace", workspace)
4386  .SetParam("no_bias", no_bias)
4387  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
4388  .SetParam("cudnn_off", cudnn_off)
4389  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
4390  .SetInput("data", data)
4391  .SetInput("weight", weight)
4392  .SetInput("bias", bias)
4393  .CreateSymbol(symbol_name);
4394 }
4395 
4399  kBilinear = 0,
4400  kNearest = 1
4401 };
4402 
4407  kConcat = 0,
4408  kSum = 1
4409 };
4410 
4426 inline Symbol UpSampling(const std::string& symbol_name,
4427  const std::vector<Symbol>& data,
4428  uint32_t scale,
4429  UpSamplingSampleType sample_type,
4430  int num_args,
4431  uint32_t num_filter = 0,
4433  uint64_t workspace = 512) {
4434  static const char *UpSamplingSampleTypeValues[] = {
4435  "bilinear",
4436  "nearest"
4437  };
4438  static const char *UpSamplingMultiInputModeValues[] = {
4439  "concat",
4440  "sum"
4441  };
4442  return Operator("UpSampling")
4443  .SetParam("scale", scale)
4444  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
4445  .SetParam("num_args", num_args)
4446  .SetParam("num_filter", num_filter)
4447  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
4448  .SetParam("workspace", workspace)
4449 (data)
4450  .CreateSymbol(symbol_name);
4451 }
4452 
4494 inline Symbol Concat(const std::string& symbol_name,
4495  const std::vector<Symbol>& data,
4496  int num_args,
4497  int dim = 1) {
4498  return Operator("Concat")
4499  .SetParam("num_args", num_args)
4500  .SetParam("dim", dim)
4501 (data)
4502  .CreateSymbol(symbol_name);
4503 }
4504 
4543 inline Symbol LayerNorm(const std::string& symbol_name,
4544  Symbol data,
4545  Symbol gamma,
4546  Symbol beta,
4547  int axis = -1,
4548  mx_float eps = 1e-05,
4549  bool output_mean_var = false) {
4550  return Operator("LayerNorm")
4551  .SetParam("axis", axis)
4552  .SetParam("eps", eps)
4553  .SetParam("output_mean_var", output_mean_var)
4554  .SetInput("data", data)
4555  .SetInput("gamma", gamma)
4556  .SetInput("beta", beta)
4557  .CreateSymbol(symbol_name);
4558 }
4559 
4587 inline Symbol LRN(const std::string& symbol_name,
4588  Symbol data,
4589  uint32_t nsize,
4590  mx_float alpha = 0.0001,
4591  mx_float beta = 0.75,
4592  mx_float knorm = 2) {
4593  return Operator("LRN")
4594  .SetParam("nsize", nsize)
4595  .SetParam("alpha", alpha)
4596  .SetParam("beta", beta)
4597  .SetParam("knorm", knorm)
4598  .SetInput("data", data)
4599  .CreateSymbol(symbol_name);
4600 }
4601 
/*!
 * \brief Modes accepted by Dropout().
 *
 * Dropout() uses each value as an index into its string table, so the
 * values must start at 0 and stay consecutive.
 */
enum class DropoutMode : int {
  kAlways = 0,   // "always"
  kTraining = 1  // "training"
};
4608 
4649 inline Symbol Dropout(const std::string& symbol_name,
4650  Symbol data,
4651  mx_float p = 0.5,
4653  Shape axes = Shape()) {
4654  static const char *DropoutModeValues[] = {
4655  "always",
4656  "training"
4657  };
4658  return Operator("Dropout")
4659  .SetParam("p", p)
4660  .SetParam("mode", DropoutModeValues[int(mode)])
4661  .SetParam("axes", axes)
4662  .SetInput("data", data)
4663  .CreateSymbol(symbol_name);
4664 }
4665 
4670  kChannel = 0,
4671  kInstance = 1
4672 };
4673 
4707 inline Symbol SoftmaxActivation(const std::string& symbol_name,
4708  Symbol data,
4710  static const char *SoftmaxActivationModeValues[] = {
4711  "channel",
4712  "instance"
4713  };
4714  return Operator("SoftmaxActivation")
4715  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
4716  .SetInput("data", data)
4717  .CreateSymbol(symbol_name);
4718 }
4719 
4757 inline Symbol FullyConnected(const std::string& symbol_name,
4758  Symbol data,
4759  Symbol weight,
4760  Symbol bias,
4761  int num_hidden,
4762  bool no_bias = false,
4763  bool flatten = true) {
4764  return Operator("FullyConnected")
4765  .SetParam("num_hidden", num_hidden)
4766  .SetParam("no_bias", no_bias)
4767  .SetParam("flatten", flatten)
4768  .SetInput("data", data)
4769  .SetInput("weight", weight)
4770  .SetInput("bias", bias)
4771  .CreateSymbol(symbol_name);
4772 }
4773 
/*!
 * \brief Padding modes accepted by Pad().
 *
 * Pad() uses each value as an index into its string table, so the values
 * must start at 0 and stay consecutive.
 */
enum class PadMode : int {
  kConstant = 0,  // "constant"
  kEdge = 1,      // "edge"
  kReflect = 2    // "reflect"
};
4782 
4879 inline Symbol Pad(const std::string& symbol_name,
4880  Symbol data,
4881  PadMode mode,
4882  Shape pad_width,
4883  double constant_value = 0) {
4884  static const char *PadModeValues[] = {
4885  "constant",
4886  "edge",
4887  "reflect"
4888  };
4889  return Operator("Pad")
4890  .SetParam("mode", PadModeValues[int(mode)])
4891  .SetParam("pad_width", pad_width)
4892  .SetParam("constant_value", constant_value)
4893  .SetInput("data", data)
4894  .CreateSymbol(symbol_name);
4895 }
4896 
/*!
 * \brief Activation kinds accepted by LeakyReLU().
 *
 * LeakyReLU() uses each value as an index into its string table, so the
 * values must start at 0 and stay consecutive.
 */
enum class LeakyReLUActType : int {
  kElu = 0,    // "elu"
  kLeaky = 1,  // "leaky"
  kPrelu = 2,  // "prelu"
  kRrelu = 3   // "rrelu"
};
4905 
4934 inline Symbol LeakyReLU(const std::string& symbol_name,
4935  Symbol data,
4936  Symbol gamma,
4938  mx_float slope = 0.25,
4939  mx_float lower_bound = 0.125,
4940  mx_float upper_bound = 0.334) {
4941  static const char *LeakyReLUActTypeValues[] = {
4942  "elu",
4943  "leaky",
4944  "prelu",
4945  "rrelu"
4946  };
4947  return Operator("LeakyReLU")
4948  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
4949  .SetParam("slope", slope)
4950  .SetParam("lower_bound", lower_bound)
4951  .SetParam("upper_bound", upper_bound)
4952  .SetInput("data", data)
4953  .SetInput("gamma", gamma)
4954  .CreateSymbol(symbol_name);
4955 }
4956 
4985 inline Symbol SwapAxis(const std::string& symbol_name,
4986  Symbol data,
4987  uint32_t dim1 = 0,
4988  uint32_t dim2 = 0) {
4989  return Operator("SwapAxis")
4990  .SetParam("dim1", dim1)
4991  .SetParam("dim2", dim2)
4992  .SetInput("data", data)
4993  .CreateSymbol(symbol_name);
4994 }
4995 
5053 inline Symbol BatchNorm_v1(const std::string& symbol_name,
5054  Symbol data,
5055  Symbol gamma,
5056  Symbol beta,
5057  mx_float eps = 0.001,
5058  mx_float momentum = 0.9,
5059  bool fix_gamma = true,
5060  bool use_global_stats = false,
5061  bool output_mean_var = false) {
5062  return Operator("BatchNorm_v1")
5063  .SetParam("eps", eps)
5064  .SetParam("momentum", momentum)
5065  .SetParam("fix_gamma", fix_gamma)
5066  .SetParam("use_global_stats", use_global_stats)
5067  .SetParam("output_mean_var", output_mean_var)
5068  .SetInput("data", data)
5069  .SetInput("gamma", gamma)
5070  .SetInput("beta", beta)
5071  .CreateSymbol(symbol_name);
5072 }
5073 
5111 inline Symbol softmax_cross_entropy(const std::string& symbol_name,
5112  Symbol data,
5113  Symbol label) {
5114  return Operator("softmax_cross_entropy")
5115  .SetInput("data", data)
5116  .SetInput("label", label)
5117  .CreateSymbol(symbol_name);
5118 }
5119 
5149 inline Symbol LinearRegressionOutput(const std::string& symbol_name,
5150  Symbol data,
5151  Symbol label,
5152  mx_float grad_scale = 1) {
5153  return Operator("LinearRegressionOutput")
5154  .SetParam("grad_scale", grad_scale)
5155  .SetInput("data", data)
5156  .SetInput("label", label)
5157  .CreateSymbol(symbol_name);
5158 }
5159 
5190 inline Symbol MAERegressionOutput(const std::string& symbol_name,
5191  Symbol data,
5192  Symbol label,
5193  mx_float grad_scale = 1) {
5194  return Operator("MAERegressionOutput")
5195  .SetParam("grad_scale", grad_scale)
5196  .SetInput("data", data)
5197  .SetInput("label", label)
5198  .CreateSymbol(symbol_name);
5199 }
5200 
5231 inline Symbol LogisticRegressionOutput(const std::string& symbol_name,
5232  Symbol data,
5233  Symbol label,
5234  mx_float grad_scale = 1) {
5235  return Operator("LogisticRegressionOutput")
5236  .SetParam("grad_scale", grad_scale)
5237  .SetInput("data", data)
5238  .SetInput("label", label)
5239  .CreateSymbol(symbol_name);
5240 }
5241 
5251 inline Symbol IdentityAttachKLSparseReg(const std::string& symbol_name,
5252  Symbol data,
5253  mx_float sparseness_target = 0.1,
5254  mx_float penalty = 0.001,
5255  mx_float momentum = 0.9) {
5256  return Operator("IdentityAttachKLSparseReg")
5257  .SetParam("sparseness_target", sparseness_target)
5258  .SetParam("penalty", penalty)
5259  .SetParam("momentum", momentum)
5260  .SetInput("data", data)
5261  .CreateSymbol(symbol_name);
5262 }
5263 
5292 inline Symbol signsgd_update(const std::string& symbol_name,
5293  Symbol weight,
5294  Symbol grad,
5295  mx_float lr,
5296  mx_float wd = 0,
5297  mx_float rescale_grad = 1,
5298  mx_float clip_gradient = -1) {
5299  return Operator("signsgd_update")
5300  .SetParam("lr", lr)
5301  .SetParam("wd", wd)
5302  .SetParam("rescale_grad", rescale_grad)
5303  .SetParam("clip_gradient", clip_gradient)
5304  .SetInput("weight", weight)
5305  .SetInput("grad", grad)
5306  .CreateSymbol(symbol_name);
5307 }
5308 
5343 inline Symbol signum_update(const std::string& symbol_name,
5344  Symbol weight,
5345  Symbol grad,
5346  Symbol mom,
5347  mx_float lr,
5348  mx_float momentum = 0,
5349  mx_float wd = 0,
5350  mx_float rescale_grad = 1,
5351  mx_float clip_gradient = -1,
5352  mx_float wd_lh = 0) {
5353  return Operator("signum_update")
5354  .SetParam("lr", lr)
5355  .SetParam("momentum", momentum)
5356  .SetParam("wd", wd)
5357  .SetParam("rescale_grad", rescale_grad)
5358  .SetParam("clip_gradient", clip_gradient)
5359  .SetParam("wd_lh", wd_lh)
5360  .SetInput("weight", weight)
5361  .SetInput("grad", grad)
5362  .SetInput("mom", mom)
5363  .CreateSymbol(symbol_name);
5364 }
5365 
5393 inline Symbol sgd_update(const std::string& symbol_name,
5394  Symbol weight,
5395  Symbol grad,
5396  mx_float lr,
5397  mx_float wd = 0,
5398  mx_float rescale_grad = 1,
5399  mx_float clip_gradient = -1) {
5400  return Operator("sgd_update")
5401  .SetParam("lr", lr)
5402  .SetParam("wd", wd)
5403  .SetParam("rescale_grad", rescale_grad)
5404  .SetParam("clip_gradient", clip_gradient)
5405  .SetInput("weight", weight)
5406  .SetInput("grad", grad)
5407  .CreateSymbol(symbol_name);
5408 }
5409 
5455 inline Symbol sgd_mom_update(const std::string& symbol_name,
5456  Symbol weight,
5457  Symbol grad,
5458  Symbol mom,
5459  mx_float lr,
5460  mx_float momentum = 0,
5461  mx_float wd = 0,
5462  mx_float rescale_grad = 1,
5463  mx_float clip_gradient = -1) {
5464  return Operator("sgd_mom_update")
5465  .SetParam("lr", lr)
5466  .SetParam("momentum", momentum)
5467  .SetParam("wd", wd)
5468  .SetParam("rescale_grad", rescale_grad)
5469  .SetParam("clip_gradient", clip_gradient)
5470  .SetInput("weight", weight)
5471  .SetInput("grad", grad)
5472  .SetInput("mom", mom)
5473  .CreateSymbol(symbol_name);
5474 }
5475 
5490 inline Symbol mp_sgd_update(const std::string& symbol_name,
5491  Symbol weight,
5492  Symbol grad,
5493  Symbol weight32,
5494  mx_float lr,
5495  mx_float wd = 0,
5496  mx_float rescale_grad = 1,
5497  mx_float clip_gradient = -1) {
5498  return Operator("mp_sgd_update")
5499  .SetParam("lr", lr)
5500  .SetParam("wd", wd)
5501  .SetParam("rescale_grad", rescale_grad)
5502  .SetParam("clip_gradient", clip_gradient)
5503  .SetInput("weight", weight)
5504  .SetInput("grad", grad)
5505  .SetInput("weight32", weight32)
5506  .CreateSymbol(symbol_name);
5507 }
5508 
5525 inline Symbol mp_sgd_mom_update(const std::string& symbol_name,
5526  Symbol weight,
5527  Symbol grad,
5528  Symbol mom,
5529  Symbol weight32,
5530  mx_float lr,
5531  mx_float momentum = 0,
5532  mx_float wd = 0,
5533  mx_float rescale_grad = 1,
5534  mx_float clip_gradient = -1) {
5535  return Operator("mp_sgd_mom_update")
5536  .SetParam("lr", lr)
5537  .SetParam("momentum", momentum)
5538  .SetParam("wd", wd)
5539  .SetParam("rescale_grad", rescale_grad)
5540  .SetParam("clip_gradient", clip_gradient)
5541  .SetInput("weight", weight)
5542  .SetInput("grad", grad)
5543  .SetInput("mom", mom)
5544  .SetInput("weight32", weight32)
5545  .CreateSymbol(symbol_name);
5546 }
5547 
5582 inline Symbol ftml_update(const std::string& symbol_name,
5583  Symbol weight,
5584  Symbol grad,
5585  Symbol d,
5586  Symbol v,
5587  Symbol z,
5588  mx_float lr,
5589  mx_float beta1 = 0.9,
5590  mx_float beta2 = 0.999,
5591  mx_float epsilon = 1e-08,
5592  mx_float wd = 0,
5593  mx_float rescale_grad = 1,
5594  mx_float clip_gradient = -1) {
5595  return Operator("ftml_update")
5596  .SetParam("lr", lr)
5597  .SetParam("beta1", beta1)
5598  .SetParam("beta2", beta2)
5599  .SetParam("epsilon", epsilon)
5600  .SetParam("wd", wd)
5601  .SetParam("rescale_grad", rescale_grad)
5602  .SetParam("clip_gradient", clip_gradient)
5603  .SetInput("weight", weight)
5604  .SetInput("grad", grad)
5605  .SetInput("d", d)
5606  .SetInput("v", v)
5607  .SetInput("z", z)
5608  .CreateSymbol(symbol_name);
5609 }
5610 
5658 inline Symbol adam_update(const std::string& symbol_name,
5659  Symbol weight,
5660  Symbol grad,
5661  Symbol mean,
5662  Symbol var,
5663  mx_float lr,
5664  mx_float beta1 = 0.9,
5665  mx_float beta2 = 0.999,
5666  mx_float epsilon = 1e-08,
5667  mx_float wd = 0,
5668  mx_float rescale_grad = 1,
5669  mx_float clip_gradient = -1) {
5670  return Operator("adam_update")
5671  .SetParam("lr", lr)
5672  .SetParam("beta1", beta1)
5673  .SetParam("beta2", beta2)
5674  .SetParam("epsilon", epsilon)
5675  .SetParam("wd", wd)
5676  .SetParam("rescale_grad", rescale_grad)
5677  .SetParam("clip_gradient", clip_gradient)
5678  .SetInput("weight", weight)
5679  .SetInput("grad", grad)
5680  .SetInput("mean", mean)
5681  .SetInput("var", var)
5682  .CreateSymbol(symbol_name);
5683 }
5684 
5738 inline Symbol rmsprop_update(const std::string& symbol_name,
5739  Symbol weight,
5740  Symbol grad,
5741  Symbol n,
5742  mx_float lr,
5743  mx_float gamma1 = 0.95,
5744  mx_float epsilon = 1e-08,
5745  mx_float wd = 0,
5746  mx_float rescale_grad = 1,
5747  mx_float clip_gradient = -1,
5748  mx_float clip_weights = -1) {
5749  return Operator("rmsprop_update")
5750  .SetParam("lr", lr)
5751  .SetParam("gamma1", gamma1)
5752  .SetParam("epsilon", epsilon)
5753  .SetParam("wd", wd)
5754  .SetParam("rescale_grad", rescale_grad)
5755  .SetParam("clip_gradient", clip_gradient)
5756  .SetParam("clip_weights", clip_weights)
5757  .SetInput("weight", weight)
5758  .SetInput("grad", grad)
5759  .SetInput("n", n)
5760  .CreateSymbol(symbol_name);
5761 }
5762 
5808 inline Symbol rmspropalex_update(const std::string& symbol_name,
5809  Symbol weight,
5810  Symbol grad,
5811  Symbol n,
5812  Symbol g,
5813  Symbol delta,
5814  mx_float lr,
5815  mx_float gamma1 = 0.95,
5816  mx_float gamma2 = 0.9,
5817  mx_float epsilon = 1e-08,
5818  mx_float wd = 0,
5819  mx_float rescale_grad = 1,
5820  mx_float clip_gradient = -1,
5821  mx_float clip_weights = -1) {
5822  return Operator("rmspropalex_update")
5823  .SetParam("lr", lr)
5824  .SetParam("gamma1", gamma1)
5825  .SetParam("gamma2", gamma2)
5826  .SetParam("epsilon", epsilon)
5827  .SetParam("wd", wd)
5828  .SetParam("rescale_grad", rescale_grad)
5829  .SetParam("clip_gradient", clip_gradient)
5830  .SetParam("clip_weights", clip_weights)
5831  .SetInput("weight", weight)
5832  .SetInput("grad", grad)
5833  .SetInput("n", n)
5834  .SetInput("g", g)
5835  .SetInput("delta", delta)
5836  .CreateSymbol(symbol_name);
5837 }
5838 
5878 inline Symbol ftrl_update(const std::string& symbol_name,
5879  Symbol weight,
5880  Symbol grad,
5881  Symbol z,
5882  Symbol n,
5883  mx_float lr,
5884  mx_float lamda1 = 0.01,
5885  mx_float beta = 1,
5886  mx_float wd = 0,
5887  mx_float rescale_grad = 1,
5888  mx_float clip_gradient = -1) {
5889  return Operator("ftrl_update")
5890  .SetParam("lr", lr)
5891  .SetParam("lamda1", lamda1)
5892  .SetParam("beta", beta)
5893  .SetParam("wd", wd)
5894  .SetParam("rescale_grad", rescale_grad)
5895  .SetParam("clip_gradient", clip_gradient)
5896  .SetInput("weight", weight)
5897  .SetInput("grad", grad)
5898  .SetInput("z", z)
5899  .SetInput("n", n)
5900  .CreateSymbol(symbol_name);
5901 }
5902 
5974 inline Symbol SliceChannel(const std::string& symbol_name,
5975  Symbol data,
5976  int num_outputs,
5977  int axis = 1,
5978  bool squeeze_axis = false) {
5979  return Operator("SliceChannel")
5980  .SetParam("num_outputs", num_outputs)
5981  .SetParam("axis", axis)
5982  .SetParam("squeeze_axis", squeeze_axis)
5983  .SetInput("data", data)
5984  .CreateSymbol(symbol_name);
5985 }
5986 
6037 inline Symbol InstanceNorm(const std::string& symbol_name,
6038  Symbol data,
6039  Symbol gamma,
6040  Symbol beta,
6041  mx_float eps = 0.001) {
6042  return Operator("InstanceNorm")
6043  .SetParam("eps", eps)
6044  .SetInput("data", data)
6045  .SetInput("gamma", gamma)
6046  .SetInput("beta", beta)
6047  .CreateSymbol(symbol_name);
6048 }
6049 
6054  kAffine = 0,
6055  kWarp = 1
6056 };
6057 
6068 inline Symbol GridGenerator(const std::string& symbol_name,
6069  Symbol data,
6070  GridGeneratorTransformType transform_type,
6071  Shape target_shape = Shape(0,0)) {
6072  static const char *GridGeneratorTransformTypeValues[] = {
6073  "affine",
6074  "warp"
6075  };
6076  return Operator("GridGenerator")
6077  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
6078  .SetParam("target_shape", target_shape)
6079  .SetInput("data", data)
6080  .CreateSymbol(symbol_name);
6081 }
6082 
6086  kAvg = 0,
6087  kMax = 1,
6088  kSum = 2
6089 };
6090 
6094  kFull = 0,
6095  kValid = 1
6096 };
6097 
6149 inline Symbol Pooling_v1(const std::string& symbol_name,
6150  Symbol data,
6151  Shape kernel = Shape(),
6153  bool global_pool = false,
6155  Shape stride = Shape(),
6156  Shape pad = Shape()) {
6157  static const char *Pooling_v1PoolTypeValues[] = {
6158  "avg",
6159  "max",
6160  "sum"
6161  };
6162  static const char *Pooling_v1PoolingConventionValues[] = {
6163  "full",
6164  "valid"
6165  };
6166  return Operator("Pooling_v1")
6167  .SetParam("kernel", kernel)
6168  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
6169  .SetParam("global_pool", global_pool)
6170  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
6171  .SetParam("stride", stride)
6172  .SetParam("pad", pad)
6173  .SetInput("data", data)
6174  .CreateSymbol(symbol_name);
6175 }
6176 
/*!
 * \brief Cell modes accepted by RNN().
 *
 * RNN() uses each value as an index into its string table, so the values
 * must start at 0 and stay consecutive.
 */
enum class RNNMode : int {
  kGru = 0,       // "gru"
  kLstm = 1,      // "lstm"
  kRnn_relu = 2,  // "rnn_relu"
  kRnn_tanh = 3   // "rnn_tanh"
};
6185 
6201 inline Symbol RNN(const std::string& symbol_name,
6202  Symbol data,
6203  Symbol parameters,
6204  Symbol state,
6205  Symbol state_cell,
6206  uint32_t state_size,
6207  uint32_t num_layers,
6208  RNNMode mode,
6209  bool bidirectional = false,
6210  mx_float p = 0,
6211  bool state_outputs = false) {
6212  static const char *RNNModeValues[] = {
6213  "gru",
6214  "lstm",
6215  "rnn_relu",
6216  "rnn_tanh"
6217  };
6218  return Operator("RNN")
6219  .SetParam("state_size", state_size)
6220  .SetParam("num_layers", num_layers)
6221  .SetParam("mode", RNNModeValues[int(mode)])
6222  .SetParam("bidirectional", bidirectional)
6223  .SetParam("p", p)
6224  .SetParam("state_outputs", state_outputs)
6225  .SetInput("data", data)
6226  .SetInput("parameters", parameters)
6227  .SetInput("state", state)
6228  .SetInput("state_cell", state_cell)
6229  .CreateSymbol(symbol_name);
6230 }
6231 
6242  kNone = 0,
6243  kFastest = 1,
6244  kLimited_workspace = 2,
6245  kOff = 3
6246 };
6247 
6252  kNone = 0,
6253  kNCDHW = 1,
6254  kNCHW = 2,
6255  kNDHWC = 3,
6256  kNHWC = 4
6257 };
6258 
6289 inline Symbol Convolution_v1(const std::string& symbol_name,
6290  Symbol data,
6291  Symbol weight,
6292  Symbol bias,
6293  Shape kernel,
6294  uint32_t num_filter,
6295  Shape stride = Shape(),
6296  Shape dilate = Shape(),
6297  Shape pad = Shape(),
6298  uint32_t num_group = 1,
6299  uint64_t workspace = 1024,
6300  bool no_bias = false,
6302  bool cudnn_off = false,
6304  static const char *Convolution_v1CudnnTuneValues[] = {
6305  "None",
6306  "fastest",
6307  "limited_workspace",
6308  "off"
6309  };
6310  static const char *Convolution_v1LayoutValues[] = {
6311  "None",
6312  "NCDHW",
6313  "NCHW",
6314  "NDHWC",
6315  "NHWC"
6316  };
6317  return Operator("Convolution_v1")
6318  .SetParam("kernel", kernel)
6319  .SetParam("num_filter", num_filter)
6320  .SetParam("stride", stride)
6321  .SetParam("dilate", dilate)
6322  .SetParam("pad", pad)
6323  .SetParam("num_group", num_group)
6324  .SetParam("workspace", workspace)
6325  .SetParam("no_bias", no_bias)
6326  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
6327  .SetParam("cudnn_off", cudnn_off)
6328  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
6329  .SetInput("data", data)
6330  .SetInput("weight", weight)
6331  .SetInput("bias", bias)
6332  .CreateSymbol(symbol_name);
6333 }
6334 
6355 inline Symbol Crop(const std::string& symbol_name,
6356  const std::vector<Symbol>& data,
6357  int num_args,
6358  Shape offset = Shape(0,0),
6359  Shape h_w = Shape(0,0),
6360  bool center_crop = false) {
6361  return Operator("Crop")
6362  .SetParam("num_args", num_args)
6363  .SetParam("offset", offset)
6364  .SetParam("h_w", h_w)
6365  .SetParam("center_crop", center_crop)
6366 (data)
6367  .CreateSymbol(symbol_name);
6368 }
6369 
6446 inline Symbol SequenceReverse(const std::string& symbol_name,
6447  Symbol data,
6448  Symbol sequence_length,
6449  bool use_sequence_length = false,
6450  int axis = 0) {
6451  return Operator("SequenceReverse")
6452  .SetParam("use_sequence_length", use_sequence_length)
6453  .SetParam("axis", axis)
6454  .SetInput("data", data)
6455  .SetInput("sequence_length", sequence_length)
6456  .CreateSymbol(symbol_name);
6457 }
6458 
6462  kAffine = 0
6463 };
6464 
6468  kBilinear = 0
6469 };
6470 
6481 inline Symbol SpatialTransformer(const std::string& symbol_name,
6482  Symbol data,
6483  Symbol loc,
6484  SpatialTransformerTransformType transform_type,
6485  SpatialTransformerSamplerType sampler_type,
6486  Shape target_shape = Shape(0,0)) {
6487  static const char *SpatialTransformerTransformTypeValues[] = {
6488  "affine"
6489  };
6490  static const char *SpatialTransformerSamplerTypeValues[] = {
6491  "bilinear"
6492  };
6493  return Operator("SpatialTransformer")
6494  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
6495  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
6496  .SetParam("target_shape", target_shape)
6497  .SetInput("data", data)
6498  .SetInput("loc", loc)
6499  .CreateSymbol(symbol_name);
6500 }
6501 
6505  kBatch = 0,
6506  kNull = 1,
6507  kValid = 2
6508 };
6509 
6605 inline Symbol SoftmaxOutput(const std::string& symbol_name,
6606  Symbol data,
6607  Symbol label,
6608  mx_float grad_scale = 1,
6609  mx_float ignore_label = -1,
6610  bool multi_output = false,
6611  bool use_ignore = false,
6612  bool preserve_shape = false,
6614  bool out_grad = false,
6615  mx_float smooth_alpha = 0) {
6616  static const char *SoftmaxOutputNormalizationValues[] = {
6617  "batch",
6618  "null",
6619  "valid"
6620  };
6621  return Operator("SoftmaxOutput")
6622  .SetParam("grad_scale", grad_scale)
6623  .SetParam("ignore_label", ignore_label)
6624  .SetParam("multi_output", multi_output)
6625  .SetParam("use_ignore", use_ignore)
6626  .SetParam("preserve_shape", preserve_shape)
6627  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
6628  .SetParam("out_grad", out_grad)
6629  .SetParam("smooth_alpha", smooth_alpha)
6630  .SetInput("data", data)
6631  .SetInput("label", label)
6632  .CreateSymbol(symbol_name);
6633 }
6634 
6638  kBatch = 0,
6639  kNull = 1,
6640  kValid = 2
6641 };
6642 
6670 inline Symbol Softmax(const std::string& symbol_name,
6671  Symbol data,
6672  mx_float grad_scale = 1,
6673  mx_float ignore_label = -1,
6674  bool multi_output = false,
6675  bool use_ignore = false,
6676  bool preserve_shape = false,
6678  bool out_grad = false,
6679  mx_float smooth_alpha = 0) {
6680  static const char *SoftmaxNormalizationValues[] = {
6681  "batch",
6682  "null",
6683  "valid"
6684  };
6685  return Operator("Softmax")
6686  .SetParam("grad_scale", grad_scale)
6687  .SetParam("ignore_label", ignore_label)
6688  .SetParam("multi_output", multi_output)
6689  .SetParam("use_ignore", use_ignore)
6690  .SetParam("preserve_shape", preserve_shape)
6691  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
6692  .SetParam("out_grad", out_grad)
6693  .SetParam("smooth_alpha", smooth_alpha)
6694  .SetInput("data", data)
6695  .CreateSymbol(symbol_name);
6696 }
6697 
6778 inline Symbol BilinearSampler(const std::string& symbol_name,
6779  Symbol data,
6780  Symbol grid) {
6781  return Operator("BilinearSampler")
6782  .SetInput("data", data)
6783  .SetInput("grid", grid)
6784  .CreateSymbol(symbol_name);
6785 }
6786 
6843 inline Symbol ROIPooling(const std::string& symbol_name,
6844  Symbol data,
6845  Symbol rois,
6846  Shape pooled_size,
6847  mx_float spatial_scale) {
6848  return Operator("ROIPooling")
6849  .SetParam("pooled_size", pooled_size)
6850  .SetParam("spatial_scale", spatial_scale)
6851  .SetInput("data", data)
6852  .SetInput("rois", rois)
6853  .CreateSymbol(symbol_name);
6854 }
6855 
6911 inline Symbol SequenceLast(const std::string& symbol_name,
6912  Symbol data,
6913  Symbol sequence_length,
6914  bool use_sequence_length = false,
6915  int axis = 0) {
6916  return Operator("SequenceLast")
6917  .SetParam("use_sequence_length", use_sequence_length)
6918  .SetParam("axis", axis)
6919  .SetInput("data", data)
6920  .SetInput("sequence_length", sequence_length)
6921  .CreateSymbol(symbol_name);
6922 }
6923 
6927  kChannel = 0,
6928  kInstance = 1,
6929  kSpatial = 2
6930 };
6931 
6994 inline Symbol L2Normalization(const std::string& symbol_name,
6995  Symbol data,
6996  mx_float eps = 1e-10,
6998  static const char *L2NormalizationModeValues[] = {
6999  "channel",
7000  "instance",
7001  "spatial"
7002  };
7003  return Operator("L2Normalization")
7004  .SetParam("eps", eps)
7005  .SetParam("mode", L2NormalizationModeValues[int(mode)])
7006  .SetInput("data", data)
7007  .CreateSymbol(symbol_name);
7008 }
7009 
7015  kBatch = 0,
7016  kNull = 1,
7017  kValid = 2
7018 };
7019 
7054 inline Symbol MakeLoss(const std::string& symbol_name,
7055  Symbol data,
7056  mx_float grad_scale = 1,
7057  mx_float valid_thresh = 0,
7059  static const char *MakeLossNormalizationValues[] = {
7060  "batch",
7061  "null",
7062  "valid"
7063  };
7064  return Operator("MakeLoss")
7065  .SetParam("grad_scale", grad_scale)
7066  .SetParam("valid_thresh", valid_thresh)
7067  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
7068  .SetInput("data", data)
7069  .CreateSymbol(symbol_name);
7070 }
7071 
7087 inline Symbol SVMOutput(const std::string& symbol_name,
7088  Symbol data,
7089  Symbol label,
7090  mx_float margin = 1,
7091  mx_float regularization_coefficient = 1,
7092  bool use_linear = false) {
7093  return Operator("SVMOutput")
7094  .SetParam("margin", margin)
7095  .SetParam("regularization_coefficient", regularization_coefficient)
7096  .SetParam("use_linear", use_linear)
7097  .SetInput("data", data)
7098  .SetInput("label", label)
7099  .CreateSymbol(symbol_name);
7100 }
7101 
7151 inline Symbol Correlation(const std::string& symbol_name,
7152  Symbol data1,
7153  Symbol data2,
7154  uint32_t kernel_size = 1,
7155  uint32_t max_displacement = 1,
7156  uint32_t stride1 = 1,
7157  uint32_t stride2 = 1,
7158  uint32_t pad_size = 0,
7159  bool is_multiply = true) {
7160  return Operator("Correlation")
7161  .SetParam("kernel_size", kernel_size)
7162  .SetParam("max_displacement", max_displacement)
7163  .SetParam("stride1", stride1)
7164  .SetParam("stride2", stride2)
7165  .SetParam("pad_size", pad_size)
7166  .SetParam("is_multiply", is_multiply)
7167  .SetInput("data1", data1)
7168  .SetInput("data2", data2)
7169  .CreateSymbol(symbol_name);
7170 }
7171 
7250 inline Symbol SequenceMask(const std::string& symbol_name,
7251  Symbol data,
7252  Symbol sequence_length,
7253  bool use_sequence_length = false,
7254  mx_float value = 0,
7255  int axis = 0) {
7256  return Operator("SequenceMask")
7257  .SetParam("use_sequence_length", use_sequence_length)
7258  .SetParam("value", value)
7259  .SetParam("axis", axis)
7260  .SetInput("data", data)
7261  .SetInput("sequence_length", sequence_length)
7262  .CreateSymbol(symbol_name);
7263 }
7264 
7273 inline Symbol choose_element_0index(const std::string& symbol_name,
7274  Symbol lhs,
7275  Symbol rhs) {
7276  return Operator("choose_element_0index")
7277  .SetInput("lhs", lhs)
7278  .SetInput("rhs", rhs)
7279  .CreateSymbol(symbol_name);
7280 }
7281 
7291 inline Symbol fill_element_0index(const std::string& symbol_name,
7292  Symbol lhs,
7293  Symbol mhs,
7294  Symbol rhs) {
7295  return Operator("fill_element_0index")
7296  .SetInput("lhs", lhs)
7297  .SetInput("mhs", mhs)
7298  .SetInput("rhs", rhs)
7299  .CreateSymbol(symbol_name);
7300 }
7301 
7341 inline Symbol khatri_rao(const std::vector<Symbol>& args) {
7342  return Operator("khatri_rao")
7343 (args)
7344  .CreateSymbol();
7345 }
7346 
7361 inline Symbol Custom(const std::vector<Symbol>& data,
7362  const std::string& op_type) {
7363  return Operator("Custom")
7364 (data)
7365  .CreateSymbol();
7366 }
7367 
7390  Symbol rhs) {
7391  return Operator("broadcast_power")
7392  .SetInput("lhs", lhs)
7393  .SetInput("rhs", rhs)
7394  .CreateSymbol();
7395 }
7396 
7421  Symbol rhs) {
7422  return Operator("broadcast_maximum")
7423  .SetInput("lhs", lhs)
7424  .SetInput("rhs", rhs)
7425  .CreateSymbol();
7426 }
7427 
7452  Symbol rhs) {
7453  return Operator("broadcast_minimum")
7454  .SetInput("lhs", lhs)
7455  .SetInput("rhs", rhs)
7456  .CreateSymbol();
7457 }
7458 
7489  Symbol rhs) {
7490  return Operator("broadcast_hypot")
7491  .SetInput("lhs", lhs)
7492  .SetInput("rhs", rhs)
7493  .CreateSymbol();
7494 }
7495 
7569 inline Symbol Reshape(Symbol data,
7570  Shape shape = Shape(),
7571  bool reverse = false,
7572  Shape target_shape = Shape(),
7573  bool keep_highest = false) {
7574  return Operator("Reshape")
7575  .SetParam("shape", shape)
7576  .SetParam("reverse", reverse)
7577  .SetParam("target_shape", target_shape)
7578  .SetParam("keep_highest", keep_highest)
7579  .SetInput("data", data)
7580  .CreateSymbol();
7581 }
7582 
7615 inline Symbol Flatten(Symbol data) {
7616  return Operator("Flatten")
7617  .SetInput("data", data)
7618  .CreateSymbol();
7619 }
7620 
7657  Shape axes = Shape()) {
7658  return Operator("transpose")
7659  .SetParam("axes", axes)
7660  .SetInput("data", data)
7661  .CreateSymbol();
7662 }
7663 
7679  int axis) {
7680  return Operator("expand_dims")
7681  .SetParam("axis", axis)
7682  .SetInput("data", data)
7683  .CreateSymbol();
7684 }
7685 
7740 inline Symbol slice(Symbol data,
7741  Shape begin,
7742  Shape end,
7743  Shape step = Shape()) {
7744  return Operator("slice")
7745  .SetParam("begin", begin)
7746  .SetParam("end", end)
7747  .SetParam("step", step)
7748  .SetInput("data", data)
7749  .CreateSymbol();
7750 }
7751 
7784  int axis,
7785  int begin,
7786  dmlc::optional<int> end) {
7787  return Operator("slice_axis")
7788  .SetParam("axis", axis)
7789  .SetParam("begin", begin)
7790  .SetParam("end", end)
7791  .SetInput("data", data)
7792  .CreateSymbol();
7793 }
7794 
7856  Symbol shape_like,
7857  Shape axes = Shape()) {
7858  return Operator("slice_like")
7859  .SetParam("axes", axes)
7860  .SetInput("data", data)
7861  .SetInput("shape_like", shape_like)
7862  .CreateSymbol();
7863 }
7864 
7898 inline Symbol clip(Symbol data,
7899  mx_float a_min,
7900  mx_float a_max) {
7901  return Operator("clip")
7902  .SetParam("a_min", a_min)
7903  .SetParam("a_max", a_max)
7904  .SetInput("data", data)
7905  .CreateSymbol();
7906 }
7907 
7941 inline Symbol repeat(Symbol data,
7942  int repeats,
7943  dmlc::optional<int> axis = dmlc::optional<int>()) {
7944  return Operator("repeat")
7945  .SetParam("repeats", repeats)
7946  .SetParam("axis", axis)
7947  .SetInput("data", data)
7948  .CreateSymbol();
7949 }
7950 
7995 inline Symbol tile(Symbol data,
7996  Shape reps) {
7997  return Operator("tile")
7998  .SetParam("reps", reps)
7999  .SetInput("data", data)
8000  .CreateSymbol();
8001 }
8002 
8025 inline Symbol reverse(Symbol data,
8026  Shape axis) {
8027  return Operator("reverse")
8028  .SetParam("axis", axis)
8029  .SetInput("data", data)
8030  .CreateSymbol();
8031 }
8032 
8055 inline Symbol stack(const std::vector<Symbol>& data,
8056  int num_args,
8057  int axis = 0) {
8058  return Operator("stack")
8059  .SetParam("num_args", num_args)
8060  .SetParam("axis", axis)
8061 (data)
8062  .CreateSymbol();
8063 }
8064 
8086 inline Symbol squeeze(const std::vector<Symbol>& data,
8087  dmlc::optional<Shape> axis = dmlc::optional<Shape>()) {
8088  return Operator("squeeze")
8089  .SetParam("axis", axis)
8090 (data)
8091  .CreateSymbol();
8092 }
8093 
8116 inline Symbol zeros_like(Symbol data) {
8117  return Operator("zeros_like")
8118  .SetInput("data", data)
8119  .CreateSymbol();
8120 }
8121 
8138 inline Symbol ones_like(Symbol data) {
8139  return Operator("ones_like")
8140  .SetInput("data", data)
8141  .CreateSymbol();
8142 }
8143 
8171  Symbol rhs) {
8172  return Operator("broadcast_add")
8173  .SetInput("lhs", lhs)
8174  .SetInput("rhs", rhs)
8175  .CreateSymbol();
8176 }
8177 
8205  Symbol rhs) {
8206  return Operator("broadcast_sub")
8207  .SetInput("lhs", lhs)
8208  .SetInput("rhs", rhs)
8209  .CreateSymbol();
8210 }
8211 
8237  Symbol rhs) {
8238  return Operator("broadcast_mul")
8239  .SetInput("lhs", lhs)
8240  .SetInput("rhs", rhs)
8241  .CreateSymbol();
8242 }
8243 
8269  Symbol rhs) {
8270  return Operator("broadcast_div")
8271  .SetInput("lhs", lhs)
8272  .SetInput("rhs", rhs)
8273  .CreateSymbol();
8274 }
8275 
8298  Symbol rhs) {
8299  return Operator("broadcast_mod")
8300  .SetInput("lhs", lhs)
8301  .SetInput("rhs", rhs)
8302  .CreateSymbol();
8303 }
8304 
8324 inline Symbol add_n(const std::vector<Symbol>& args) {
8325  return Operator("add_n")
8326 (args)
8327  .CreateSymbol();
8328 }
8329 
8360 inline Symbol argmax(Symbol data,
8361  dmlc::optional<int> axis = dmlc::optional<int>(),
8362  bool keepdims = false) {
8363  return Operator("argmax")
8364  .SetParam("axis", axis)
8365  .SetParam("keepdims", keepdims)
8366  .SetInput("data", data)
8367  .CreateSymbol();
8368 }
8369 
8400 inline Symbol argmin(Symbol data,
8401  dmlc::optional<int> axis = dmlc::optional<int>(),
8402  bool keepdims = false) {
8403  return Operator("argmin")
8404  .SetParam("axis", axis)
8405  .SetParam("keepdims", keepdims)
8406  .SetInput("data", data)
8407  .CreateSymbol();
8408 }
8409 
8432  return Operator("argmax_channel")
8433  .SetInput("data", data)
8434  .CreateSymbol();
8435 }
8436 
8481 inline Symbol pick(Symbol data,
8482  Symbol index,
8483  dmlc::optional<int> axis = dmlc::optional<int>(),
8484  bool keepdims = false) {
8485  return Operator("pick")
8486  .SetParam("axis", axis)
8487  .SetParam("keepdims", keepdims)
8488  .SetInput("data", data)
8489  .SetInput("index", index)
8490  .CreateSymbol();
8491 }
8492 
8532 inline Symbol dot(Symbol lhs,
8533  Symbol rhs,
8534  bool transpose_a = false,
8535  bool transpose_b = false) {
8536  return Operator("dot")
8537  .SetParam("transpose_a", transpose_a)
8538  .SetParam("transpose_b", transpose_b)
8539  .SetInput("lhs", lhs)
8540  .SetInput("rhs", rhs)
8541  .CreateSymbol();
8542 }
8543 
8566  Symbol rhs,
8567  bool transpose_a = false,
8568  bool transpose_b = false) {
8569  return Operator("batch_dot")
8570  .SetParam("transpose_a", transpose_a)
8571  .SetParam("transpose_b", transpose_b)
8572  .SetInput("lhs", lhs)
8573  .SetInput("rhs", rhs)
8574  .CreateSymbol();
8575 }
8576 
8594 inline Symbol relu(Symbol data) {
8595  return Operator("relu")
8596  .SetInput("data", data)
8597  .CreateSymbol();
8598 }
8599 
8614 inline Symbol sigmoid(Symbol data) {
8615  return Operator("sigmoid")
8616  .SetInput("data", data)
8617  .CreateSymbol();
8618 }
8619 
8634 inline Symbol softsign(Symbol data) {
8635  return Operator("softsign")
8636  .SetInput("data", data)
8637  .CreateSymbol();
8638 }
8639 
8672 inline Symbol BlockGrad(Symbol data) {
8673  return Operator("BlockGrad")
8674  .SetInput("data", data)
8675  .CreateSymbol();
8676 }
8677 
8706 inline Symbol make_loss(Symbol data) {
8707  return Operator("make_loss")
8708  .SetInput("data", data)
8709  .CreateSymbol();
8710 }
8711 
8719  Symbol rhs) {
8720  return Operator("reshape_like")
8721  .SetInput("lhs", lhs)
8722  .SetInput("rhs", rhs)
8723  .CreateSymbol();
8724 }
8725 
8744 inline Symbol Cast(Symbol data,
8745  CastDtype dtype) {
8746  static const char *CastDtypeValues[] = {
8747  "float16",
8748  "float32",
8749  "float64",
8750  "int32",
8751  "int64",
8752  "int8",
8753  "uint8"
8754  };
8755  return Operator("Cast")
8756  .SetParam("dtype", CastDtypeValues[int(dtype)])
8757  .SetInput("data", data)
8758  .CreateSymbol();
8759 }
8760 
8774 inline Symbol negative(Symbol data) {
8775  return Operator("negative")
8776  .SetInput("data", data)
8777  .CreateSymbol();
8778 }
8779 
8795 inline Symbol reciprocal(Symbol data) {
8796  return Operator("reciprocal")
8797  .SetInput("data", data)
8798  .CreateSymbol();
8799 }
8800 
8819 inline Symbol abs(Symbol data) {
8820  return Operator("abs")
8821  .SetInput("data", data)
8822  .CreateSymbol();
8823 }
8824 
8843 inline Symbol sign(Symbol data) {
8844  return Operator("sign")
8845  .SetInput("data", data)
8846  .CreateSymbol();
8847 }
8848 
8867 inline Symbol round(Symbol data) {
8868  return Operator("round")
8869  .SetInput("data", data)
8870  .CreateSymbol();
8871 }
8872 
8895 inline Symbol rint(Symbol data) {
8896  return Operator("rint")
8897  .SetInput("data", data)
8898  .CreateSymbol();
8899 }
8900 
8921 inline Symbol ceil(Symbol data) {
8922  return Operator("ceil")
8923  .SetInput("data", data)
8924  .CreateSymbol();
8925 }
8926 
8947 inline Symbol floor(Symbol data) {
8948  return Operator("floor")
8949  .SetInput("data", data)
8950  .CreateSymbol();
8951 }
8952 
8974 inline Symbol trunc(Symbol data) {
8975  return Operator("trunc")
8976  .SetInput("data", data)
8977  .CreateSymbol();
8978 }
8979 
8999 inline Symbol fix(Symbol data) {
9000  return Operator("fix")
9001  .SetInput("data", data)
9002  .CreateSymbol();
9003 }
9004 
9027 inline Symbol square(Symbol data) {
9028  return Operator("square")
9029  .SetInput("data", data)
9030  .CreateSymbol();
9031 }
9032 
9054 inline Symbol sqrt(Symbol data) {
9055  return Operator("sqrt")
9056  .SetInput("data", data)
9057  .CreateSymbol();
9058 }
9059 
9078 inline Symbol rsqrt(Symbol data) {
9079  return Operator("rsqrt")
9080  .SetInput("data", data)
9081  .CreateSymbol();
9082 }
9083 
9100 inline Symbol cbrt(Symbol data) {
9101  return Operator("cbrt")
9102  .SetInput("data", data)
9103  .CreateSymbol();
9104 }
9105 
9122 inline Symbol rcbrt(Symbol data) {
9123  return Operator("rcbrt")
9124  .SetInput("data", data)
9125  .CreateSymbol();
9126 }
9127 
9146 inline Symbol exp(Symbol data) {
9147  return Operator("exp")
9148  .SetInput("data", data)
9149  .CreateSymbol();
9150 }
9151 
9165 inline Symbol log(Symbol data) {
9166  return Operator("log")
9167  .SetInput("data", data)
9168  .CreateSymbol();
9169 }
9170 
9184 inline Symbol log10(Symbol data) {
9185  return Operator("log10")
9186  .SetInput("data", data)
9187  .CreateSymbol();
9188 }
9189 
9203 inline Symbol log2(Symbol data) {
9204  return Operator("log2")
9205  .SetInput("data", data)
9206  .CreateSymbol();
9207 }
9208 
9226 inline Symbol log1p(Symbol data) {
9227  return Operator("log1p")
9228  .SetInput("data", data)
9229  .CreateSymbol();
9230 }
9231 
9248 inline Symbol expm1(Symbol data) {
9249  return Operator("expm1")
9250  .SetInput("data", data)
9251  .CreateSymbol();
9252 }
9253 
9264 inline Symbol gamma(Symbol data) {
9265  return Operator("gamma")
9266  .SetInput("data", data)
9267  .CreateSymbol();
9268 }
9269 
9280 inline Symbol gammaln(Symbol data) {
9281  return Operator("gammaln")
9282  .SetInput("data", data)
9283  .CreateSymbol();
9284 }
9285 
9343 inline Symbol sum(Symbol data,
9344  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
9345  bool keepdims = false,
9346  bool exclude = false) {
9347  return Operator("sum")
9348  .SetParam("axis", axis)
9349  .SetParam("keepdims", keepdims)
9350  .SetParam("exclude", exclude)
9351  .SetInput("data", data)
9352  .CreateSymbol();
9353 }
9354 
9378 inline Symbol mean(Symbol data,
9379  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
9380  bool keepdims = false,
9381  bool exclude = false) {
9382  return Operator("mean")
9383  .SetParam("axis", axis)
9384  .SetParam("keepdims", keepdims)
9385  .SetParam("exclude", exclude)
9386  .SetInput("data", data)
9387  .CreateSymbol();
9388 }
9389 
9413 inline Symbol prod(Symbol data,
9414  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
9415  bool keepdims = false,
9416  bool exclude = false) {
9417  return Operator("prod")
9418  .SetParam("axis", axis)
9419  .SetParam("keepdims", keepdims)
9420  .SetParam("exclude", exclude)
9421  .SetInput("data", data)
9422  .CreateSymbol();
9423 }
9424 
9450 inline Symbol nansum(Symbol data,
9451  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
9452  bool keepdims = false,
9453  bool exclude = false) {
9454  return Operator("nansum")
9455  .SetParam("axis", axis)
9456  .SetParam("keepdims", keepdims)
9457  .SetParam("exclude", exclude)
9458  .SetInput("data", data)
9459  .CreateSymbol();
9460 }
9461 
9487 inline Symbol nanprod(Symbol data,
9488  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
9489  bool keepdims = false,
9490  bool exclude = false) {
9491  return Operator("nanprod")
9492  .SetParam("axis", axis)
9493  .SetParam("keepdims", keepdims)
9494  .SetParam("exclude", exclude)
9495  .SetInput("data", data)
9496  .CreateSymbol();
9497 }
9498 
9522 inline Symbol max(Symbol data,
9523  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
9524  bool keepdims = false,
9525  bool exclude = false) {
9526  return Operator("max")
9527  .SetParam("axis", axis)
9528  .SetParam("keepdims", keepdims)
9529  .SetParam("exclude", exclude)
9530  .SetInput("data", data)
9531  .CreateSymbol();
9532 }
9533 
9557 inline Symbol min(Symbol data,
9558  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
9559  bool keepdims = false,
9560  bool exclude = false) {
9561  return Operator("min")
9562  .SetParam("axis", axis)
9563  .SetParam("keepdims", keepdims)
9564  .SetParam("exclude", exclude)
9565  .SetInput("data", data)
9566  .CreateSymbol();
9567 }
9568 
9598  Shape axis = Shape(),
9599  Shape size = Shape()) {
9600  return Operator("broadcast_axis")
9601  .SetParam("axis", axis)
9602  .SetParam("size", size)
9603  .SetInput("data", data)
9604  .CreateSymbol();
9605 }
9606 
9635  Shape shape = Shape()) {
9636  return Operator("broadcast_to")
9637  .SetParam("shape", shape)
9638  .SetInput("data", data)
9639  .CreateSymbol();
9640 }
9641 
9678 inline Symbol norm(Symbol data,
9679  int ord = 2,
9680  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
9681  bool keepdims = false) {
9682  return Operator("norm")
9683  .SetParam("ord", ord)
9684  .SetParam("axis", axis)
9685  .SetParam("keepdims", keepdims)
9686  .SetInput("data", data)
9687  .CreateSymbol();
9688 }
9689 
9730 inline Symbol topk(Symbol data,
9731  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9732  int k = 1,
9733  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
9734  bool is_ascend = false) {
9735  static const char *TopkRetTypValues[] = {
9736  "both",
9737  "indices",
9738  "mask",
9739  "value"
9740  };
9741  return Operator("topk")
9742  .SetParam("axis", axis)
9743  .SetParam("k", k)
9744  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
9745  .SetParam("is_ascend", is_ascend)
9746  .SetInput("data", data)
9747  .CreateSymbol();
9748 }
9749 
9781 inline Symbol sort(Symbol data,
9782  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9783  bool is_ascend = true) {
9784  return Operator("sort")
9785  .SetParam("axis", axis)
9786  .SetParam("is_ascend", is_ascend)
9787  .SetInput("data", data)
9788  .CreateSymbol();
9789 }
9790 
9820 inline Symbol argsort(Symbol data,
9821  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9822  bool is_ascend = true) {
9823  return Operator("argsort")
9824  .SetParam("axis", axis)
9825  .SetParam("is_ascend", is_ascend)
9826  .SetInput("data", data)
9827  .CreateSymbol();
9828 }
9829 
9845  Symbol rhs) {
9846  return Operator("elemwise_add")
9847  .SetInput("lhs", lhs)
9848  .SetInput("rhs", rhs)
9849  .CreateSymbol();
9850 }
9851 
9867  Symbol rhs) {
9868  return Operator("elemwise_sub")
9869  .SetInput("lhs", lhs)
9870  .SetInput("rhs", rhs)
9871  .CreateSymbol();
9872 }
9873 
9892  Symbol rhs) {
9893  return Operator("elemwise_mul")
9894  .SetInput("lhs", lhs)
9895  .SetInput("rhs", rhs)
9896  .CreateSymbol();
9897 }
9898 
9910  Symbol rhs) {
9911  return Operator("elemwise_div")
9912  .SetInput("lhs", lhs)
9913  .SetInput("rhs", rhs)
9914  .CreateSymbol();
9915 }
9916 
9968  Symbol weight,
9969  int input_dim,
9970  int output_dim,
9972  static const char *EmbeddingDtypeValues[] = {
9973  "float16",
9974  "float32",
9975  "float64",
9976  "int32",
9977  "int64",
9978  "int8",
9979  "uint8"
9980  };
9981  return Operator("Embedding")
9982  .SetParam("input_dim", input_dim)
9983  .SetParam("output_dim", output_dim)
9984  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
9985  .SetInput("data", data)
9986  .SetInput("weight", weight)
9987  .CreateSymbol();
9988 }
9989 
10032 inline Symbol take(Symbol a,
10033  Symbol indices,
10034  int axis = 0,
10035  TakeMode mode = TakeMode::kClip) {
10036  static const char *TakeModeValues[] = {
10037  "clip",
10038  "raise",
10039  "wrap"
10040  };
10041  return Operator("take")
10042  .SetParam("axis", axis)
10043  .SetParam("mode", TakeModeValues[int(mode)])
10044  .SetInput("a", a)
10045  .SetInput("indices", indices)
10046  .CreateSymbol();
10047 }
10048 
10077  Symbol indices) {
10078  return Operator("batch_take")
10079  .SetInput("a", a)
10080  .SetInput("indices", indices)
10081  .CreateSymbol();
10082 }
10083 
10127 inline Symbol one_hot(Symbol indices,
10128  int depth,
10129  double on_value = 1,
10130  double off_value = 0,
10132  static const char *One_hotDtypeValues[] = {
10133  "float16",
10134  "float32",
10135  "float64",
10136  "int32",
10137  "int64",
10138  "int8",
10139  "uint8"
10140  };
10141  return Operator("one_hot")
10142  .SetParam("depth", depth)
10143  .SetParam("on_value", on_value)
10144  .SetParam("off_value", off_value)
10145  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
10146  .SetInput("indices", indices)
10147  .CreateSymbol();
10148 }
10149 
10177  Symbol indices) {
10178  return Operator("gather_nd")
10179  .SetInput("data", data)
10180  .SetInput("indices", indices)
10181  .CreateSymbol();
10182 }
10183 
10220  Symbol indices,
10221  Shape shape) {
10222  return Operator("scatter_nd")
10223  .SetParam("shape", shape)
10224  .SetInput("data", data)
10225  .SetInput("indices", indices)
10226  .CreateSymbol();
10227 }
10228 
10251  Symbol rhs) {
10252  return Operator("broadcast_equal")
10253  .SetInput("lhs", lhs)
10254  .SetInput("rhs", rhs)
10255  .CreateSymbol();
10256 }
10257 
10280  Symbol rhs) {
10281  return Operator("broadcast_not_equal")
10282  .SetInput("lhs", lhs)
10283  .SetInput("rhs", rhs)
10284  .CreateSymbol();
10285 }
10286 
10309  Symbol rhs) {
10310  return Operator("broadcast_greater")
10311  .SetInput("lhs", lhs)
10312  .SetInput("rhs", rhs)
10313  .CreateSymbol();
10314 }
10315 
10338  Symbol rhs) {
10339  return Operator("broadcast_greater_equal")
10340  .SetInput("lhs", lhs)
10341  .SetInput("rhs", rhs)
10342  .CreateSymbol();
10343 }
10344 
10367  Symbol rhs) {
10368  return Operator("broadcast_lesser")
10369  .SetInput("lhs", lhs)
10370  .SetInput("rhs", rhs)
10371  .CreateSymbol();
10372 }
10373 
10396  Symbol rhs) {
10397  return Operator("broadcast_lesser_equal")
10398  .SetInput("lhs", lhs)
10399  .SetInput("rhs", rhs)
10400  .CreateSymbol();
10401 }
10402 
10437 inline Symbol where(Symbol condition,
10438  Symbol x,
10439  Symbol y) {
10440  return Operator("where")
10441  .SetInput("condition", condition)
10442  .SetInput("x", x)
10443  .SetInput("y", y)
10444  .CreateSymbol();
10445 }
10446 
10472  mx_float scalar) {
10473  return Operator("smooth_l1")
10474  .SetParam("scalar", scalar)
10475  .SetInput("data", data)
10476  .CreateSymbol();
10477 }
10478 
10524  Cast_storageStype stype) {
10525  static const char *Cast_storageStypeValues[] = {
10526  "csr",
10527  "default",
10528  "row_sparse"
10529  };
10530  return Operator("cast_storage")
10531  .SetParam("stype", Cast_storageStypeValues[int(stype)])
10532  .SetInput("data", data)
10533  .CreateSymbol();
10534 }
10535 
10555 inline Symbol sin(Symbol data) {
10556  return Operator("sin")
10557  .SetInput("data", data)
10558  .CreateSymbol();
10559 }
10560 
10577 inline Symbol cos(Symbol data) {
10578  return Operator("cos")
10579  .SetInput("data", data)
10580  .CreateSymbol();
10581 }
10582 
10602 inline Symbol tan(Symbol data) {
10603  return Operator("tan")
10604  .SetInput("data", data)
10605  .CreateSymbol();
10606 }
10607 
10628 inline Symbol arcsin(Symbol data) {
10629  return Operator("arcsin")
10630  .SetInput("data", data)
10631  .CreateSymbol();
10632 }
10633 
10651 inline Symbol arccos(Symbol data) {
10652  return Operator("arccos")
10653  .SetInput("data", data)
10654  .CreateSymbol();
10655 }
10656 
10676 inline Symbol arctan(Symbol data) {
10677  return Operator("arctan")
10678  .SetInput("data", data)
10679  .CreateSymbol();
10680 }
10681 
10699 inline Symbol degrees(Symbol data) {
10700  return Operator("degrees")
10701  .SetInput("data", data)
10702  .CreateSymbol();
10703 }
10704 
10722 inline Symbol radians(Symbol data) {
10723  return Operator("radians")
10724  .SetInput("data", data)
10725  .CreateSymbol();
10726 }
10727 
10745 inline Symbol sinh(Symbol data) {
10746  return Operator("sinh")
10747  .SetInput("data", data)
10748  .CreateSymbol();
10749 }
10750 
10765 inline Symbol cosh(Symbol data) {
10766  return Operator("cosh")
10767  .SetInput("data", data)
10768  .CreateSymbol();
10769 }
10770 
10788 inline Symbol tanh(Symbol data) {
10789  return Operator("tanh")
10790  .SetInput("data", data)
10791  .CreateSymbol();
10792 }
10793 
10809 inline Symbol arcsinh(Symbol data) {
10810  return Operator("arcsinh")
10811  .SetInput("data", data)
10812  .CreateSymbol();
10813 }
10814 
10827 inline Symbol arccosh(Symbol data) {
10828  return Operator("arccosh")
10829  .SetInput("data", data)
10830  .CreateSymbol();
10831 }
10832 
10848 inline Symbol arctanh(Symbol data) {
10849  return Operator("arctanh")
10850  .SetInput("data", data)
10851  .CreateSymbol();
10852 }
10853 
10874 inline Symbol Pooling(Symbol data,
10875  Shape kernel = Shape(),
10877  bool global_pool = false,
10878  bool cudnn_off = false,
10880  Shape stride = Shape(),
10881  Shape pad = Shape()) {
10882  static const char *PoolingPoolTypeValues[] = {
10883  "avg",
10884  "max",
10885  "sum"
10886  };
10887  static const char *PoolingPoolingConventionValues[] = {
10888  "full",
10889  "valid"
10890  };
10891  return Operator("Pooling")
10892  .SetParam("kernel", kernel)
10893  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
10894  .SetParam("global_pool", global_pool)
10895  .SetParam("cudnn_off", cudnn_off)
10896  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
10897  .SetParam("stride", stride)
10898  .SetParam("pad", pad)
10899  .SetInput("data", data)
10900  .CreateSymbol();
10901 }
10902 
10931 inline Symbol softmax(Symbol data,
10932  int axis = -1) {
10933  return Operator("softmax")
10934  .SetParam("axis", axis)
10935  .SetInput("data", data)
10936  .CreateSymbol();
10937 }
10938 
10961  int axis = -1) {
10962  return Operator("log_softmax")
10963  .SetParam("axis", axis)
10964  .SetInput("data", data)
10965  .CreateSymbol();
10966 }
10967 
10997  Symbol weight,
10998  Symbol bias,
10999  Shape kernel,
11000  uint32_t num_filter,
11001  Shape stride = Shape(),
11002  Shape dilate = Shape(),
11003  Shape pad = Shape(),
11004  Shape adj = Shape(),
11005  Shape target_shape = Shape(),
11006  uint32_t num_group = 1,
11007  uint64_t workspace = 512,
11008  bool no_bias = true,
11010  bool cudnn_off = false,
11012  static const char *DeconvolutionCudnnTuneValues[] = {
11013  "None",
11014  "fastest",
11015  "limited_workspace",
11016  "off"
11017  };
11018  static const char *DeconvolutionLayoutValues[] = {
11019  "None",
11020  "NCDHW",
11021  "NCHW",
11022  "NCW",
11023  "NDHWC",
11024  "NHWC"
11025  };
11026  return Operator("Deconvolution")
11027  .SetParam("kernel", kernel)
11028  .SetParam("num_filter", num_filter)
11029  .SetParam("stride", stride)
11030  .SetParam("dilate", dilate)
11031  .SetParam("pad", pad)
11032  .SetParam("adj", adj)
11033  .SetParam("target_shape", target_shape)
11034  .SetParam("num_group", num_group)
11035  .SetParam("workspace", workspace)
11036  .SetParam("no_bias", no_bias)
11037  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
11038  .SetParam("cudnn_off", cudnn_off)
11039  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
11040  .SetInput("data", data)
11041  .SetInput("weight", weight)
11042  .SetInput("bias", bias)
11043  .CreateSymbol();
11044 }
11045 
11065  ActivationActType act_type) {
11066  static const char *ActivationActTypeValues[] = {
11067  "relu",
11068  "sigmoid",
11069  "softrelu",
11070  "softsign",
11071  "tanh"
11072  };
11073  return Operator("Activation")
11074  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
11075  .SetInput("data", data)
11076  .CreateSymbol();
11077 }
11078 
11143  Symbol gamma,
11144  Symbol beta,
11145  Symbol moving_mean,
11146  Symbol moving_var,
11147  double eps = 0.001,
11148  mx_float momentum = 0.9,
11149  bool fix_gamma = true,
11150  bool use_global_stats = false,
11151  bool output_mean_var = false,
11152  int axis = 1,
11153  bool cudnn_off = false) {
11154  return Operator("BatchNorm")
11155  .SetParam("eps", eps)
11156  .SetParam("momentum", momentum)
11157  .SetParam("fix_gamma", fix_gamma)
11158  .SetParam("use_global_stats", use_global_stats)
11159  .SetParam("output_mean_var", output_mean_var)
11160  .SetParam("axis", axis)
11161  .SetParam("cudnn_off", cudnn_off)
11162  .SetInput("data", data)
11163  .SetInput("gamma", gamma)
11164  .SetInput("beta", beta)
11165  .SetInput("moving_mean", moving_mean)
11166  .SetInput("moving_var", moving_var)
11167  .CreateSymbol();
11168 }
11169 
11267  Symbol weight,
11268  Symbol bias,
11269  Shape kernel,
11270  uint32_t num_filter,
11271  Shape stride = Shape(),
11272  Shape dilate = Shape(),
11273  Shape pad = Shape(),
11274  uint32_t num_group = 1,
11275  uint64_t workspace = 1024,
11276  bool no_bias = false,
11278  bool cudnn_off = false,
11280  static const char *ConvolutionCudnnTuneValues[] = {
11281  "None",
11282  "fastest",
11283  "limited_workspace",
11284  "off"
11285  };
11286  static const char *ConvolutionLayoutValues[] = {
11287  "None",
11288  "NCDHW",
11289  "NCHW",
11290  "NCW",
11291  "NDHWC",
11292  "NHWC"
11293  };
11294  return Operator("Convolution")
11295  .SetParam("kernel", kernel)
11296  .SetParam("num_filter", num_filter)
11297  .SetParam("stride", stride)
11298  .SetParam("dilate", dilate)
11299  .SetParam("pad", pad)
11300  .SetParam("num_group", num_group)
11301  .SetParam("workspace", workspace)
11302  .SetParam("no_bias", no_bias)
11303  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
11304  .SetParam("cudnn_off", cudnn_off)
11305  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
11306  .SetInput("data", data)
11307  .SetInput("weight", weight)
11308  .SetInput("bias", bias)
11309  .CreateSymbol();
11310 }
11311 
11326 inline Symbol UpSampling(const std::vector<Symbol>& data,
11327  uint32_t scale,
11328  UpSamplingSampleType sample_type,
11329  int num_args,
11330  uint32_t num_filter = 0,
11332  uint64_t workspace = 512) {
11333  static const char *UpSamplingSampleTypeValues[] = {
11334  "bilinear",
11335  "nearest"
11336  };
11337  static const char *UpSamplingMultiInputModeValues[] = {
11338  "concat",
11339  "sum"
11340  };
11341  return Operator("UpSampling")
11342  .SetParam("scale", scale)
11343  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
11344  .SetParam("num_args", num_args)
11345  .SetParam("num_filter", num_filter)
11346  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
11347  .SetParam("workspace", workspace)
11348 (data)
11349  .CreateSymbol();
11350 }
11351 
11392 inline Symbol Concat(const std::vector<Symbol>& data,
11393  int num_args,
11394  int dim = 1) {
11395  return Operator("Concat")
11396  .SetParam("num_args", num_args)
11397  .SetParam("dim", dim)
11398 (data)
11399  .CreateSymbol();
11400 }
11401 
11440  Symbol gamma,
11441  Symbol beta,
11442  int axis = -1,
11443  mx_float eps = 1e-05,
11444  bool output_mean_var = false) {
11445  return Operator("LayerNorm")
11446  .SetParam("axis", axis)
11447  .SetParam("eps", eps)
11448  .SetParam("output_mean_var", output_mean_var)
11449  .SetInput("data", data)
11450  .SetInput("gamma", gamma)
11451  .SetInput("beta", beta)
11452  .CreateSymbol();
11453 }
11454 
11481 inline Symbol LRN(Symbol data,
11482  uint32_t nsize,
11483  mx_float alpha = 0.0001,
11484  mx_float beta = 0.75,
11485  mx_float knorm = 2) {
11486  return Operator("LRN")
11487  .SetParam("nsize", nsize)
11488  .SetParam("alpha", alpha)
11489  .SetParam("beta", beta)
11490  .SetParam("knorm", knorm)
11491  .SetInput("data", data)
11492  .CreateSymbol();
11493 }
11494 
11534 inline Symbol Dropout(Symbol data,
11535  mx_float p = 0.5,
11537  Shape axes = Shape()) {
11538  static const char *DropoutModeValues[] = {
11539  "always",
11540  "training"
11541  };
11542  return Operator("Dropout")
11543  .SetParam("p", p)
11544  .SetParam("mode", DropoutModeValues[int(mode)])
11545  .SetParam("axes", axes)
11546  .SetInput("data", data)
11547  .CreateSymbol();
11548 }
11549 
11584  static const char *SoftmaxActivationModeValues[] = {
11585  "channel",
11586  "instance"
11587  };
11588  return Operator("SoftmaxActivation")
11589  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
11590  .SetInput("data", data)
11591  .CreateSymbol();
11592 }
11593 
11631  Symbol weight,
11632  Symbol bias,
11633  int num_hidden,
11634  bool no_bias = false,
11635  bool flatten = true) {
11636  return Operator("FullyConnected")
11637  .SetParam("num_hidden", num_hidden)
11638  .SetParam("no_bias", no_bias)
11639  .SetParam("flatten", flatten)
11640  .SetInput("data", data)
11641  .SetInput("weight", weight)
11642  .SetInput("bias", bias)
11643  .CreateSymbol();
11644 }
11645 
11741 inline Symbol Pad(Symbol data,
11742  PadMode mode,
11743  Shape pad_width,
11744  double constant_value = 0) {
11745  static const char *PadModeValues[] = {
11746  "constant",
11747  "edge",
11748  "reflect"
11749  };
11750  return Operator("Pad")
11751  .SetParam("mode", PadModeValues[int(mode)])
11752  .SetParam("pad_width", pad_width)
11753  .SetParam("constant_value", constant_value)
11754  .SetInput("data", data)
11755  .CreateSymbol();
11756 }
11757 
11786  Symbol gamma,
11788  mx_float slope = 0.25,
11789  mx_float lower_bound = 0.125,
11790  mx_float upper_bound = 0.334) {
11791  static const char *LeakyReLUActTypeValues[] = {
11792  "elu",
11793  "leaky",
11794  "prelu",
11795  "rrelu"
11796  };
11797  return Operator("LeakyReLU")
11798  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
11799  .SetParam("slope", slope)
11800  .SetParam("lower_bound", lower_bound)
11801  .SetParam("upper_bound", upper_bound)
11802  .SetInput("data", data)
11803  .SetInput("gamma", gamma)
11804  .CreateSymbol();
11805 }
11806 
11834 inline Symbol SwapAxis(Symbol data,
11835  uint32_t dim1 = 0,
11836  uint32_t dim2 = 0) {
11837  return Operator("SwapAxis")
11838  .SetParam("dim1", dim1)
11839  .SetParam("dim2", dim2)
11840  .SetInput("data", data)
11841  .CreateSymbol();
11842 }
11843 
11901  Symbol gamma,
11902  Symbol beta,
11903  mx_float eps = 0.001,
11904  mx_float momentum = 0.9,
11905  bool fix_gamma = true,
11906  bool use_global_stats = false,
11907  bool output_mean_var = false) {
11908  return Operator("BatchNorm_v1")
11909  .SetParam("eps", eps)
11910  .SetParam("momentum", momentum)
11911  .SetParam("fix_gamma", fix_gamma)
11912  .SetParam("use_global_stats", use_global_stats)
11913  .SetParam("output_mean_var", output_mean_var)
11914  .SetInput("data", data)
11915  .SetInput("gamma", gamma)
11916  .SetInput("beta", beta)
11917  .CreateSymbol();
11918 }
11919 
11957  Symbol label) {
11958  return Operator("softmax_cross_entropy")
11959  .SetInput("data", data)
11960  .SetInput("label", label)
11961  .CreateSymbol();
11962 }
11963 
11993  Symbol label,
11994  mx_float grad_scale = 1) {
11995  return Operator("LinearRegressionOutput")
11996  .SetParam("grad_scale", grad_scale)
11997  .SetInput("data", data)
11998  .SetInput("label", label)
11999  .CreateSymbol();
12000 }
12001 
12032  Symbol label,
12033  mx_float grad_scale = 1) {
12034  return Operator("MAERegressionOutput")
12035  .SetParam("grad_scale", grad_scale)
12036  .SetInput("data", data)
12037  .SetInput("label", label)
12038  .CreateSymbol();
12039 }
12040 
12071  Symbol label,
12072  mx_float grad_scale = 1) {
12073  return Operator("LogisticRegressionOutput")
12074  .SetParam("grad_scale", grad_scale)
12075  .SetInput("data", data)
12076  .SetInput("label", label)
12077  .CreateSymbol();
12078 }
12079 
12089  mx_float sparseness_target = 0.1,
12090  mx_float penalty = 0.001,
12091  mx_float momentum = 0.9) {
12092  return Operator("IdentityAttachKLSparseReg")
12093  .SetParam("sparseness_target", sparseness_target)
12094  .SetParam("penalty", penalty)
12095  .SetParam("momentum", momentum)
12096  .SetInput("data", data)
12097  .CreateSymbol();
12098 }
12099 
12128  Symbol grad,
12129  mx_float lr,
12130  mx_float wd = 0,
12131  mx_float rescale_grad = 1,
12132  mx_float clip_gradient = -1) {
12133  return Operator("signsgd_update")
12134  .SetParam("lr", lr)
12135  .SetParam("wd", wd)
12136  .SetParam("rescale_grad", rescale_grad)
12137  .SetParam("clip_gradient", clip_gradient)
12138  .SetInput("weight", weight)
12139  .SetInput("grad", grad)
12140  .CreateSymbol();
12141 }
12142 
12177  Symbol grad,
12178  Symbol mom,
12179  mx_float lr,
12180  mx_float momentum = 0,
12181  mx_float wd = 0,
12182  mx_float rescale_grad = 1,
12183  mx_float clip_gradient = -1,
12184  mx_float wd_lh = 0) {
12185  return Operator("signum_update")
12186  .SetParam("lr", lr)
12187  .SetParam("momentum", momentum)
12188  .SetParam("wd", wd)
12189  .SetParam("rescale_grad", rescale_grad)
12190  .SetParam("clip_gradient", clip_gradient)
12191  .SetParam("wd_lh", wd_lh)
12192  .SetInput("weight", weight)
12193  .SetInput("grad", grad)
12194  .SetInput("mom", mom)
12195  .CreateSymbol();
12196 }
12197 
12224 inline Symbol sgd_update(Symbol weight,
12225  Symbol grad,
12226  mx_float lr,
12227  mx_float wd = 0,
12228  mx_float rescale_grad = 1,
12229  mx_float clip_gradient = -1) {
12230  return Operator("sgd_update")
12231  .SetParam("lr", lr)
12232  .SetParam("wd", wd)
12233  .SetParam("rescale_grad", rescale_grad)
12234  .SetParam("clip_gradient", clip_gradient)
12235  .SetInput("weight", weight)
12236  .SetInput("grad", grad)
12237  .CreateSymbol();
12238 }
12239 
12285  Symbol grad,
12286  Symbol mom,
12287  mx_float lr,
12288  mx_float momentum = 0,
12289  mx_float wd = 0,
12290  mx_float rescale_grad = 1,
12291  mx_float clip_gradient = -1) {
12292  return Operator("sgd_mom_update")
12293  .SetParam("lr", lr)
12294  .SetParam("momentum", momentum)
12295  .SetParam("wd", wd)
12296  .SetParam("rescale_grad", rescale_grad)
12297  .SetParam("clip_gradient", clip_gradient)
12298  .SetInput("weight", weight)
12299  .SetInput("grad", grad)
12300  .SetInput("mom", mom)
12301  .CreateSymbol();
12302 }
12303 
12318  Symbol grad,
12319  Symbol weight32,
12320  mx_float lr,
12321  mx_float wd = 0,
12322  mx_float rescale_grad = 1,
12323  mx_float clip_gradient = -1) {
12324  return Operator("mp_sgd_update")
12325  .SetParam("lr", lr)
12326  .SetParam("wd", wd)
12327  .SetParam("rescale_grad", rescale_grad)
12328  .SetParam("clip_gradient", clip_gradient)
12329  .SetInput("weight", weight)
12330  .SetInput("grad", grad)
12331  .SetInput("weight32", weight32)
12332  .CreateSymbol();
12333 }
12334 
12351  Symbol grad,
12352  Symbol mom,
12353  Symbol weight32,
12354  mx_float lr,
12355  mx_float momentum = 0,
12356  mx_float wd = 0,
12357  mx_float rescale_grad = 1,
12358  mx_float clip_gradient = -1) {
12359  return Operator("mp_sgd_mom_update")
12360  .SetParam("lr", lr)
12361  .SetParam("momentum", momentum)
12362  .SetParam("wd", wd)
12363  .SetParam("rescale_grad", rescale_grad)
12364  .SetParam("clip_gradient", clip_gradient)
12365  .SetInput("weight", weight)
12366  .SetInput("grad", grad)
12367  .SetInput("mom", mom)
12368  .SetInput("weight32", weight32)
12369  .CreateSymbol();
12370 }
12371 
12405 inline Symbol ftml_update(Symbol weight,
12406  Symbol grad,
12407  Symbol d,
12408  Symbol v,
12409  Symbol z,
12410  mx_float lr,
12411  mx_float beta1 = 0.9,
12412  mx_float beta2 = 0.999,
12413  mx_float epsilon = 1e-08,
12414  mx_float wd = 0,
12415  mx_float rescale_grad = 1,
12416  mx_float clip_gradient = -1) {
12417  return Operator("ftml_update")
12418  .SetParam("lr", lr)
12419  .SetParam("beta1", beta1)
12420  .SetParam("beta2", beta2)
12421  .SetParam("epsilon", epsilon)
12422  .SetParam("wd", wd)
12423  .SetParam("rescale_grad", rescale_grad)
12424  .SetParam("clip_gradient", clip_gradient)
12425  .SetInput("weight", weight)
12426  .SetInput("grad", grad)
12427  .SetInput("d", d)
12428  .SetInput("v", v)
12429  .SetInput("z", z)
12430  .CreateSymbol();
12431 }
12432 
12479 inline Symbol adam_update(Symbol weight,
12480  Symbol grad,
12481  Symbol mean,
12482  Symbol var,
12483  mx_float lr,
12484  mx_float beta1 = 0.9,
12485  mx_float beta2 = 0.999,
12486  mx_float epsilon = 1e-08,
12487  mx_float wd = 0,
12488  mx_float rescale_grad = 1,
12489  mx_float clip_gradient = -1) {
12490  return Operator("adam_update")
12491  .SetParam("lr", lr)
12492  .SetParam("beta1", beta1)
12493  .SetParam("beta2", beta2)
12494  .SetParam("epsilon", epsilon)
12495  .SetParam("wd", wd)
12496  .SetParam("rescale_grad", rescale_grad)
12497  .SetParam("clip_gradient", clip_gradient)
12498  .SetInput("weight", weight)
12499  .SetInput("grad", grad)
12500  .SetInput("mean", mean)
12501  .SetInput("var", var)
12502  .CreateSymbol();
12503 }
12504 
12558  Symbol grad,
12559  Symbol n,
12560  mx_float lr,
12561  mx_float gamma1 = 0.95,
12562  mx_float epsilon = 1e-08,
12563  mx_float wd = 0,
12564  mx_float rescale_grad = 1,
12565  mx_float clip_gradient = -1,
12566  mx_float clip_weights = -1) {
12567  return Operator("rmsprop_update")
12568  .SetParam("lr", lr)
12569  .SetParam("gamma1", gamma1)
12570  .SetParam("epsilon", epsilon)
12571  .SetParam("wd", wd)
12572  .SetParam("rescale_grad", rescale_grad)
12573  .SetParam("clip_gradient", clip_gradient)
12574  .SetParam("clip_weights", clip_weights)
12575  .SetInput("weight", weight)
12576  .SetInput("grad", grad)
12577  .SetInput("n", n)
12578  .CreateSymbol();
12579 }
12580 
12626  Symbol grad,
12627  Symbol n,
12628  Symbol g,
12629  Symbol delta,
12630  mx_float lr,
12631  mx_float gamma1 = 0.95,
12632  mx_float gamma2 = 0.9,
12633  mx_float epsilon = 1e-08,
12634  mx_float wd = 0,
12635  mx_float rescale_grad = 1,
12636  mx_float clip_gradient = -1,
12637  mx_float clip_weights = -1) {
12638  return Operator("rmspropalex_update")
12639  .SetParam("lr", lr)
12640  .SetParam("gamma1", gamma1)
12641  .SetParam("gamma2", gamma2)
12642  .SetParam("epsilon", epsilon)
12643  .SetParam("wd", wd)
12644  .SetParam("rescale_grad", rescale_grad)
12645  .SetParam("clip_gradient", clip_gradient)
12646  .SetParam("clip_weights", clip_weights)
12647  .SetInput("weight", weight)
12648  .SetInput("grad", grad)
12649  .SetInput("n", n)
12650  .SetInput("g", g)
12651  .SetInput("delta", delta)
12652  .CreateSymbol();
12653 }
12654 
12693 inline Symbol ftrl_update(Symbol weight,
12694  Symbol grad,
12695  Symbol z,
12696  Symbol n,
12697  mx_float lr,
12698  mx_float lamda1 = 0.01,
12699  mx_float beta = 1,
12700  mx_float wd = 0,
12701  mx_float rescale_grad = 1,
12702  mx_float clip_gradient = -1) {
12703  return Operator("ftrl_update")
12704  .SetParam("lr", lr)
12705  .SetParam("lamda1", lamda1)
12706  .SetParam("beta", beta)
12707  .SetParam("wd", wd)
12708  .SetParam("rescale_grad", rescale_grad)
12709  .SetParam("clip_gradient", clip_gradient)
12710  .SetInput("weight", weight)
12711  .SetInput("grad", grad)
12712  .SetInput("z", z)
12713  .SetInput("n", n)
12714  .CreateSymbol();
12715 }
12716 
12788  int num_outputs,
12789  int axis = 1,
12790  bool squeeze_axis = false) {
12791  return Operator("SliceChannel")
12792  .SetParam("num_outputs", num_outputs)
12793  .SetParam("axis", axis)
12794  .SetParam("squeeze_axis", squeeze_axis)
12795  .SetInput("data", data)
12796  .CreateSymbol();
12797 }
12798 
12849  Symbol gamma,
12850  Symbol beta,
12851  mx_float eps = 0.001) {
12852  return Operator("InstanceNorm")
12853  .SetParam("eps", eps)
12854  .SetInput("data", data)
12855  .SetInput("gamma", gamma)
12856  .SetInput("beta", beta)
12857  .CreateSymbol();
12858 }
12859 
12870  GridGeneratorTransformType transform_type,
12871  Shape target_shape = Shape(0,0)) {
12872  static const char *GridGeneratorTransformTypeValues[] = {
12873  "affine",
12874  "warp"
12875  };
12876  return Operator("GridGenerator")
12877  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
12878  .SetParam("target_shape", target_shape)
12879  .SetInput("data", data)
12880  .CreateSymbol();
12881 }
12882 
12934  Shape kernel = Shape(),
12936  bool global_pool = false,
12938  Shape stride = Shape(),
12939  Shape pad = Shape()) {
12940  static const char *Pooling_v1PoolTypeValues[] = {
12941  "avg",
12942  "max",
12943  "sum"
12944  };
12945  static const char *Pooling_v1PoolingConventionValues[] = {
12946  "full",
12947  "valid"
12948  };
12949  return Operator("Pooling_v1")
12950  .SetParam("kernel", kernel)
12951  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
12952  .SetParam("global_pool", global_pool)
12953  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
12954  .SetParam("stride", stride)
12955  .SetParam("pad", pad)
12956  .SetInput("data", data)
12957  .CreateSymbol();
12958 }
12959 
12974 inline Symbol RNN(Symbol data,
12975  Symbol parameters,
12976  Symbol state,
12977  Symbol state_cell,
12978  uint32_t state_size,
12979  uint32_t num_layers,
12980  RNNMode mode,
12981  bool bidirectional = false,
12982  mx_float p = 0,
12983  bool state_outputs = false) {
12984  static const char *RNNModeValues[] = {
12985  "gru",
12986  "lstm",
12987  "rnn_relu",
12988  "rnn_tanh"
12989  };
12990  return Operator("RNN")
12991  .SetParam("state_size", state_size)
12992  .SetParam("num_layers", num_layers)
12993  .SetParam("mode", RNNModeValues[int(mode)])
12994  .SetParam("bidirectional", bidirectional)
12995  .SetParam("p", p)
12996  .SetParam("state_outputs", state_outputs)
12997  .SetInput("data", data)
12998  .SetInput("parameters", parameters)
12999  .SetInput("state", state)
13000  .SetInput("state_cell", state_cell)
13001  .CreateSymbol();
13002 }
13003 
13034  Symbol weight,
13035  Symbol bias,
13036  Shape kernel,
13037  uint32_t num_filter,
13038  Shape stride = Shape(),
13039  Shape dilate = Shape(),
13040  Shape pad = Shape(),
13041  uint32_t num_group = 1,
13042  uint64_t workspace = 1024,
13043  bool no_bias = false,
13045  bool cudnn_off = false,
13047  static const char *Convolution_v1CudnnTuneValues[] = {
13048  "None",
13049  "fastest",
13050  "limited_workspace",
13051  "off"
13052  };
13053  static const char *Convolution_v1LayoutValues[] = {
13054  "None",
13055  "NCDHW",
13056  "NCHW",
13057  "NDHWC",
13058  "NHWC"
13059  };
13060  return Operator("Convolution_v1")
13061  .SetParam("kernel", kernel)
13062  .SetParam("num_filter", num_filter)
13063  .SetParam("stride", stride)
13064  .SetParam("dilate", dilate)
13065  .SetParam("pad", pad)
13066  .SetParam("num_group", num_group)
13067  .SetParam("workspace", workspace)
13068  .SetParam("no_bias", no_bias)
13069  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
13070  .SetParam("cudnn_off", cudnn_off)
13071  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
13072  .SetInput("data", data)
13073  .SetInput("weight", weight)
13074  .SetInput("bias", bias)
13075  .CreateSymbol();
13076 }
13077 
13097 inline Symbol Crop(const std::vector<Symbol>& data,
13098  int num_args,
13099  Shape offset = Shape(0,0),
13100  Shape h_w = Shape(0,0),
13101  bool center_crop = false) {
13102  return Operator("Crop")
13103  .SetParam("num_args", num_args)
13104  .SetParam("offset", offset)
13105  .SetParam("h_w", h_w)
13106  .SetParam("center_crop", center_crop)
13107 (data)
13108  .CreateSymbol();
13109 }
13110 
13187  Symbol sequence_length,
13188  bool use_sequence_length = false,
13189  int axis = 0) {
13190  return Operator("SequenceReverse")
13191  .SetParam("use_sequence_length", use_sequence_length)
13192  .SetParam("axis", axis)
13193  .SetInput("data", data)
13194  .SetInput("sequence_length", sequence_length)
13195  .CreateSymbol();
13196 }
13197 
13208  Symbol loc,
13209  SpatialTransformerTransformType transform_type,
13210  SpatialTransformerSamplerType sampler_type,
13211  Shape target_shape = Shape(0,0)) {
13212  static const char *SpatialTransformerTransformTypeValues[] = {
13213  "affine"
13214  };
13215  static const char *SpatialTransformerSamplerTypeValues[] = {
13216  "bilinear"
13217  };
13218  return Operator("SpatialTransformer")
13219  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
13220  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
13221  .SetParam("target_shape", target_shape)
13222  .SetInput("data", data)
13223  .SetInput("loc", loc)
13224  .CreateSymbol();
13225 }
13226 
13322  Symbol label,
13323  mx_float grad_scale = 1,
13324  mx_float ignore_label = -1,
13325  bool multi_output = false,
13326  bool use_ignore = false,
13327  bool preserve_shape = false,
13329  bool out_grad = false,
13330  mx_float smooth_alpha = 0) {
13331  static const char *SoftmaxOutputNormalizationValues[] = {
13332  "batch",
13333  "null",
13334  "valid"
13335  };
13336  return Operator("SoftmaxOutput")
13337  .SetParam("grad_scale", grad_scale)
13338  .SetParam("ignore_label", ignore_label)
13339  .SetParam("multi_output", multi_output)
13340  .SetParam("use_ignore", use_ignore)
13341  .SetParam("preserve_shape", preserve_shape)
13342  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
13343  .SetParam("out_grad", out_grad)
13344  .SetParam("smooth_alpha", smooth_alpha)
13345  .SetInput("data", data)
13346  .SetInput("label", label)
13347  .CreateSymbol();
13348 }
13349 
13376 inline Symbol Softmax(Symbol data,
13377  mx_float grad_scale = 1,
13378  mx_float ignore_label = -1,
13379  bool multi_output = false,
13380  bool use_ignore = false,
13381  bool preserve_shape = false,
13383  bool out_grad = false,
13384  mx_float smooth_alpha = 0) {
13385  static const char *SoftmaxNormalizationValues[] = {
13386  "batch",
13387  "null",
13388  "valid"
13389  };
13390  return Operator("Softmax")
13391  .SetParam("grad_scale", grad_scale)
13392  .SetParam("ignore_label", ignore_label)
13393  .SetParam("multi_output", multi_output)
13394  .SetParam("use_ignore", use_ignore)
13395  .SetParam("preserve_shape", preserve_shape)
13396  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
13397  .SetParam("out_grad", out_grad)
13398  .SetParam("smooth_alpha", smooth_alpha)
13399  .SetInput("data", data)
13400  .CreateSymbol();
13401 }
13402 
13483  Symbol grid) {
13484  return Operator("BilinearSampler")
13485  .SetInput("data", data)
13486  .SetInput("grid", grid)
13487  .CreateSymbol();
13488 }
13489 
13546  Symbol rois,
13547  Shape pooled_size,
13548  mx_float spatial_scale) {
13549  return Operator("ROIPooling")
13550  .SetParam("pooled_size", pooled_size)
13551  .SetParam("spatial_scale", spatial_scale)
13552  .SetInput("data", data)
13553  .SetInput("rois", rois)
13554  .CreateSymbol();
13555 }
13556 
13612  Symbol sequence_length,
13613  bool use_sequence_length = false,
13614  int axis = 0) {
13615  return Operator("SequenceLast")
13616  .SetParam("use_sequence_length", use_sequence_length)
13617  .SetParam("axis", axis)
13618  .SetInput("data", data)
13619  .SetInput("sequence_length", sequence_length)
13620  .CreateSymbol();
13621 }
13622 
13685  mx_float eps = 1e-10,
13687  static const char *L2NormalizationModeValues[] = {
13688  "channel",
13689  "instance",
13690  "spatial"
13691  };
13692  return Operator("L2Normalization")
13693  .SetParam("eps", eps)
13694  .SetParam("mode", L2NormalizationModeValues[int(mode)])
13695  .SetInput("data", data)
13696  .CreateSymbol();
13697 }
13698 
13732 inline Symbol MakeLoss(Symbol data,
13733  mx_float grad_scale = 1,
13734  mx_float valid_thresh = 0,
13736  static const char *MakeLossNormalizationValues[] = {
13737  "batch",
13738  "null",
13739  "valid"
13740  };
13741  return Operator("MakeLoss")
13742  .SetParam("grad_scale", grad_scale)
13743  .SetParam("valid_thresh", valid_thresh)
13744  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
13745  .SetInput("data", data)
13746  .CreateSymbol();
13747 }
13748 
13764  Symbol label,
13765  mx_float margin = 1,
13766  mx_float regularization_coefficient = 1,
13767  bool use_linear = false) {
13768  return Operator("SVMOutput")
13769  .SetParam("margin", margin)
13770  .SetParam("regularization_coefficient", regularization_coefficient)
13771  .SetParam("use_linear", use_linear)
13772  .SetInput("data", data)
13773  .SetInput("label", label)
13774  .CreateSymbol();
13775 }
13776 
13826  Symbol data2,
13827  uint32_t kernel_size = 1,
13828  uint32_t max_displacement = 1,
13829  uint32_t stride1 = 1,
13830  uint32_t stride2 = 1,
13831  uint32_t pad_size = 0,
13832  bool is_multiply = true) {
13833  return Operator("Correlation")
13834  .SetParam("kernel_size", kernel_size)
13835  .SetParam("max_displacement", max_displacement)
13836  .SetParam("stride1", stride1)
13837  .SetParam("stride2", stride2)
13838  .SetParam("pad_size", pad_size)
13839  .SetParam("is_multiply", is_multiply)
13840  .SetInput("data1", data1)
13841  .SetInput("data2", data2)
13842  .CreateSymbol();
13843 }
13844 
13923  Symbol sequence_length,
13924  bool use_sequence_length = false,
13925  mx_float value = 0,
13926  int axis = 0) {
13927  return Operator("SequenceMask")
13928  .SetParam("use_sequence_length", use_sequence_length)
13929  .SetParam("value", value)
13930  .SetParam("axis", axis)
13931  .SetInput("data", data)
13932  .SetInput("sequence_length", sequence_length)
13933  .CreateSymbol();
13934 }
13935 
13944  Symbol rhs) {
13945  return Operator("choose_element_0index")
13946  .SetInput("lhs", lhs)
13947  .SetInput("rhs", rhs)
13948  .CreateSymbol();
13949 }
13950 
13960  Symbol mhs,
13961  Symbol rhs) {
13962  return Operator("fill_element_0index")
13963  .SetInput("lhs", lhs)
13964  .SetInput("mhs", mhs)
13965  .SetInput("rhs", rhs)
13966  .CreateSymbol();
13967 }
13968 
13969 } //namespace cpp
13970 } //namespace mxnet
13971 #endif // MXNET_CPP_OP_H_
Symbol Convolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, ConvolutionCudnnTune cudnn_tune=ConvolutionCudnnTune::kNone, bool cudnn_off=false, ConvolutionLayout layout=ConvolutionLayout::kNone)
Definition: op.h:4349
Symbol fix(const std::string &symbol_name, Symbol data)
Definition: op.h:1830
Symbol Crop(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, Shape offset=Shape(0, 0), Shape h_w=Shape(0, 0), bool center_crop=false)
Definition: op.h:6355
Symbol broadcast_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1003
Symbol arcsin(const std::string &symbol_name, Symbol data)
Definition: op.h:3611
Symbol FullyConnected(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, int num_hidden, bool no_bias=false, bool flatten=true)
Definition: op.h:4757
Symbol arccosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3828
Symbol arctan(const std::string &symbol_name, Symbol data)
Definition: op.h:3663
Symbol SwapAxis(const std::string &symbol_name, Symbol data, uint32_t dim1=0, uint32_t dim2=0)
Definition: op.h:4985
Symbol cast_storage(const std::string &symbol_name, Symbol data, Cast_storageStype stype)
Definition: op.h:3498
Symbol add_n(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:1097
Symbol log1p(const std::string &symbol_name, Symbol data)
Definition: op.h:2077
SoftmaxActivationMode
Definition: op.h:4669
Symbol mp_sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol weight32, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5490
Symbol SpatialTransformer(const std::string &symbol_name, Symbol data, Symbol loc, SpatialTransformerTransformType transform_type, SpatialTransformerSamplerType sampler_type, Shape target_shape=Shape(0, 0))
Definition: op.h:6481
Symbol slice(const std::string &symbol_name, Symbol data, Shape begin, Shape end, Shape step=Shape())
Definition: op.h:481
Symbol exp(const std::string &symbol_name, Symbol data)
Definition: op.h:1989
Symbol transpose(const std::string &symbol_name, Symbol data, Shape axes=Shape())
Definition: op.h:393
Symbol clip(const std::string &symbol_name, Symbol data, mx_float a_min, mx_float a_max)
Definition: op.h:645
Symbol elemwise_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2812
Symbol ROIPooling(const std::string &symbol_name, Symbol data, Symbol rois, Shape pooled_size, mx_float spatial_scale)
Definition: op.h:6843
Symbol broadcast_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1037
Convolution_v1Layout
Definition: op.h:6251
Symbol argmin(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1177
Symbol SequenceReverse(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:6446
Symbol mp_sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, Symbol weight32, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5525
Symbol broadcast_lesser(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3325
Symbol fill_element_0index(const std::string &symbol_name, Symbol lhs, Symbol mhs, Symbol rhs)
Definition: op.h:7291
Symbol Convolution_v1(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, Convolution_v1CudnnTune cudnn_tune=Convolution_v1CudnnTune::kNone, bool cudnn_off=false, Convolution_v1Layout layout=Convolution_v1Layout::kNone)
Definition: op.h:6289
Symbol broadcast_not_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3232
TakeMode
Definition: op.h:2912
Symbol Embedding(const std::string &symbol_name, Symbol data, Symbol weight, int input_dim, int output_dim, EmbeddingDtype dtype=EmbeddingDtype::kFloat32)
Definition: op.h:2884
Symbol SequenceLast(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:6911
Symbol ftrl_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol z, Symbol n, mx_float lr, mx_float lamda1=0.01, mx_float beta=1, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5878
Symbol reciprocal(const std::string &symbol_name, Symbol data)
Definition: op.h:1610
TopkRetTyp
Definition: op.h:2573
Symbol RNN(const std::string &symbol_name, Symbol data, Symbol parameters, Symbol state, Symbol state_cell, uint32_t state_size, uint32_t num_layers, RNNMode mode, bool bidirectional=false, mx_float p=0, bool state_outputs=false)
Definition: op.h:6201
namespace of mxnet
Definition: base.h:118
Symbol reshape_like(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1515
Pooling_v1PoolingConvention
Definition: op.h:6093
Symbol broadcast_lesser_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3356
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
Symbol InstanceNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001)
Definition: op.h:6037
Symbol sign(const std::string &symbol_name, Symbol data)
Definition: op.h:1662
GridGeneratorTransformType
Definition: op.h:6053
Cast_storageStype
Definition: op.h:3447
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
Symbol ones_like(const std::string &symbol_name, Symbol data)
Definition: op.h:899
RNNMode
Definition: op.h:6179
PadMode
Definition: op.h:4777
Symbol smooth_l1(const std::string &symbol_name, Symbol data, mx_float scalar)
Definition: op.h:3436
Symbol where(const std::string &symbol_name, Symbol condition, Symbol x, Symbol y)
Definition: op.h:3400
Symbol expm1(const std::string &symbol_name, Symbol data)
Definition: op.h:2101
Symbol elemwise_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2741
PoolingPoolType
Definition: op.h:3860
Symbol relu(const std::string &symbol_name, Symbol data)
Definition: op.h:1381
Symbol reverse(const std::string &symbol_name, Symbol data, Shape axis)
Definition: op.h:778
Symbol rsqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1915
SpatialTransformerTransformType
Definition: op.h:6461
ActivationActType
Definition: op.h:4095
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1889
Symbol Softmax(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxNormalization normalization=SoftmaxNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:6670
Symbol rint(const std::string &symbol_name, Symbol data)
Definition: op.h:1718
Symbol IdentityAttachKLSparseReg(const std::string &symbol_name, Symbol data, mx_float sparseness_target=0.1, mx_float penalty=0.001, mx_float momentum=0.9)
Definition: op.h:5251
Symbol sinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3738
Symbol scatter_nd(const std::string &symbol_name, Symbol data, Symbol indices, Shape shape)
Definition: op.h:3168
Symbol broadcast_greater_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3294
Symbol LRN(const std::string &symbol_name, Symbol data, uint32_t nsize, mx_float alpha=0.0001, mx_float beta=0.75, mx_float knorm=2)
Definition: op.h:4587
Symbol sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5455
Symbol LayerNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, int axis=-1, mx_float eps=1e-05, bool output_mean_var=false)
Definition: op.h:4543
Symbol arcsinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3808
Symbol sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5393
Symbol MAERegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:5190
Symbol SliceChannel(const std::string &symbol_name, Symbol data, int num_outputs, int axis=1, bool squeeze_axis=false)
Definition: op.h:5974
PoolingPoolingConvention
Definition: op.h:3868
Symbol broadcast_minimum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:180
Symbol broadcast_maximum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:147
Symbol Cast(const std::string &symbol_name, Symbol data, CastDtype dtype)
Definition: op.h:1555
DeconvolutionLayout
Definition: op.h:4004
Symbol trunc(const std::string &symbol_name, Symbol data)
Definition: op.h:1803
Pooling_v1PoolType
Definition: op.h:6085
Symbol round(const std::string &symbol_name, Symbol data)
Definition: op.h:1688
Symbol Dropout(const std::string &symbol_name, Symbol data, mx_float p=0.5, DropoutMode mode=DropoutMode::kTraining, Shape axes=Shape())
Definition: op.h:4649
Symbol squeeze(const std::string &symbol_name, const std::vector< Symbol > &data, dmlc::optional< Shape > axis=dmlc::optional< Shape >())
Definition: op.h:843
Symbol log_softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:3984
Symbol khatri_rao(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:62
Symbol cos(const std::string &symbol_name, Symbol data)
Definition: op.h:3556
Symbol L2Normalization(const std::string &symbol_name, Symbol data, mx_float eps=1e-10, L2NormalizationMode mode=L2NormalizationMode::kInstance)
Definition: op.h:6994
Symbol max(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2391
Symbol Correlation(const std::string &symbol_name, Symbol data1, Symbol data2, uint32_t kernel_size=1, uint32_t max_displacement=1, uint32_t stride1=1, uint32_t stride2=1, uint32_t pad_size=0, bool is_multiply=true)
Definition: op.h:7151
Symbol zeros_like(const std::string &symbol_name, Symbol data)
Definition: op.h:875
EmbeddingDtype
Definition: op.h:2823
Symbol batch_dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1350
Symbol broadcast_mod(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1068
Symbol cbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1939
operator helper functions
Symbol tanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3785
Symbol broadcast_to(const std::string &symbol_name, Symbol data, Shape shape=Shape())
Definition: op.h:2509
Symbol elemwise_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2765
DropoutMode
Definition: op.h:4604
Symbol norm(const std::string &symbol_name, Symbol data, int ord=2, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false)
Definition: op.h:2555
Symbol MakeLoss(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float valid_thresh=0, MakeLossNormalization normalization=MakeLossNormalization::kNull)
Definition: op.h:7054
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:2010
Symbol sigmoid(const std::string &symbol_name, Symbol data)
Definition: op.h:1403
CastDtype
Definition: op.h:1526
ConvolutionLayout
Definition: op.h:4243
Symbol LogisticRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:5231
Symbol gamma(const std::string &symbol_name, Symbol data)
Definition: op.h:2119
Symbol sin(const std::string &symbol_name, Symbol data)
Definition: op.h:3532
UpSamplingMultiInputMode
Definition: op.h:4406
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
Symbol elemwise_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2792
SpatialTransformerSamplerType
Definition: op.h:6467
Symbol Pad(const std::string &symbol_name, Symbol data, PadMode mode, Shape pad_width, double constant_value=0)
Definition: op.h:4879
Symbol square(const std::string &symbol_name, Symbol data)
Definition: op.h:1860
Symbol LeakyReLU(const std::string &symbol_name, Symbol data, Symbol gamma, LeakyReLUActType act_type=LeakyReLUActType::kLeaky, mx_float slope=0.25, mx_float lower_bound=0.125, mx_float upper_bound=0.334)
Definition: op.h:4934
One_hotDtype
Definition: op.h:3018
Symbol nansum(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2315
UpSamplingSampleType
Definition: op.h:4398
Symbol rmsprop_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, mx_float lr, mx_float gamma1=0.95, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:5738
Symbol make_loss(const std::string &symbol_name, Symbol data)
Definition: op.h:1501
Symbol SoftmaxActivation(const std::string &symbol_name, Symbol data, SoftmaxActivationMode mode=SoftmaxActivationMode::kInstance)
Definition: op.h:4707
Symbol broadcast_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3201
Symbol nanprod(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2354
Symbol Deconvolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), Shape adj=Shape(), Shape target_shape=Shape(), uint32_t num_group=1, uint64_t workspace=512, bool no_bias=true, DeconvolutionCudnnTune cudnn_tune=DeconvolutionCudnnTune::kNone, bool cudnn_off=false, DeconvolutionLayout layout=DeconvolutionLayout::kNone)
Definition: op.h:4042
Symbol broadcast_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:933
Symbol adam_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mean, Symbol var, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5658
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:58
Symbol tan(const std::string &symbol_name, Symbol data)
Definition: op.h:3583
Convolution_v1CudnnTune
Definition: op.h:6241
Symbol repeat(const std::string &symbol_name, Symbol data, int repeats, dmlc::optional< int > axis=dmlc::optional< int >())
Definition: op.h:690
Symbol slice_axis(const std::string &symbol_name, Symbol data, int axis, int begin, dmlc::optional< int > end)
Definition: op.h:526
Symbol expand_dims(const std::string &symbol_name, Symbol data, int axis)
Definition: op.h:417
Symbol arctanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3851
Symbol softmax_cross_entropy(const std::string &symbol_name, Symbol data, Symbol label)
Definition: op.h:5111
Symbol pick(const std::string &symbol_name, Symbol data, Symbol index, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1262
Symbol broadcast_axis(const std::string &symbol_name, Symbol data, Shape axis=Shape(), Shape size=Shape())
Definition: op.h:2470
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1636
Symbol cosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3760
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2674
Symbol gather_nd(const std::string &symbol_name, Symbol data, Symbol indices)
Definition: op.h:3123
Symbol slice_like(const std::string &symbol_name, Symbol data, Symbol shape_like, Shape axes=Shape())
Definition: op.h:600
Symbol Pooling(const std::string &symbol_name, Symbol data, Shape kernel=Shape(), PoolingPoolType pool_type=PoolingPoolType::kMax, bool global_pool=false, bool cudnn_off=false, PoolingPoolingConvention pooling_convention=PoolingPoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:3894
Symbol BilinearSampler(const std::string &symbol_name, Symbol data, Symbol grid)
Definition: op.h:6778
Symbol Custom(const std::string &symbol_name, const std::vector< Symbol > &data, const std::string &op_type)
Definition: op.h:84
Symbol Pooling_v1(const std::string &symbol_name, Symbol data, Shape kernel=Shape(), Pooling_v1PoolType pool_type=Pooling_v1PoolType::kMax, bool global_pool=false, Pooling_v1PoolingConvention pooling_convention=Pooling_v1PoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:6149
Symbol broadcast_hypot(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:219
Symbol BatchNorm_v1(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false)
Definition: op.h:5053
Symbol UpSampling(const std::string &symbol_name, const std::vector< Symbol > &data, uint32_t scale, UpSamplingSampleType sample_type, int num_args, uint32_t num_filter=0, UpSamplingMultiInputMode multi_input_mode=UpSamplingMultiInputMode::kConcat, uint64_t workspace=512)
Definition: op.h:4426
Symbol Activation(const std::string &symbol_name, Symbol data, ActivationActType act_type)
Definition: op.h:4122
float mx_float
manually define float
Definition: c_api.h:60
Symbol SVMOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float margin=1, mx_float regularization_coefficient=1, bool use_linear=false)
Definition: op.h:7087
Symbol radians(const std::string &symbol_name, Symbol data)
Definition: op.h:3713
Symbol Concat(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int dim=1)
Definition: op.h:4494
Symbol ftml_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol d, Symbol v, Symbol z, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5582
L2NormalizationMode
Definition: op.h:6926
Symbol SequenceMask(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, mx_float value=0, int axis=0)
Definition: op.h:7250
Symbol stack(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int axis=0)
Definition: op.h:810
Symbol floor(const std::string &symbol_name, Symbol data)
Definition: op.h:1774
Symbol broadcast_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:969
Symbol take(const std::string &symbol_name, Symbol a, Symbol indices, int axis=0, TakeMode mode=TakeMode::kClip)
Definition: op.h:2961
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:1746
Symbol gammaln(const std::string &symbol_name, Symbol data)
Definition: op.h:2137
Symbol tile(const std::string &symbol_name, Symbol data, Shape reps)
Definition: op.h:746
Symbol min(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2428
Symbol signum_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float wd_lh=0)
Definition: op.h:5343
Symbol rmspropalex_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, Symbol g, Symbol delta, mx_float lr, mx_float gamma1=0.95, mx_float gamma2=0.9, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:5808
Symbol argsort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2715
SoftmaxNormalization
Definition: op.h:6637
Symbol softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:3953
DeconvolutionCudnnTune
Definition: op.h:3995
ConvolutionCudnnTune
Definition: op.h:4233
Symbol prod(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2276
definition of shape
Symbol broadcast_greater(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3263
Symbol BatchNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, Symbol moving_mean, Symbol moving_var, double eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false, int axis=1, bool cudnn_off=false)
Definition: op.h:4202
Symbol rcbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1963
Symbol topk(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), int k=1, TopkRetTyp ret_typ=TopkRetTyp::kIndices, bool is_ascend=false)
Definition: op.h:2621
Symbol signsgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5292
Symbol broadcast_power(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:114
SoftmaxOutputNormalization
Definition: op.h:6504
Symbol SoftmaxOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxOutputNormalization normalization=SoftmaxOutputNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:6605
Symbol Flatten(const std::string &symbol_name, Symbol data)
Definition: op.h:350
Symbol BlockGrad(const std::string &symbol_name, Symbol data)
Definition: op.h:1465
LeakyReLUActType
Definition: op.h:4899
Symbol arccos(const std::string &symbol_name, Symbol data)
Definition: op.h:3636
Symbol argmax_channel(const std::string &symbol_name, Symbol data)
Definition: op.h:1210
Symbol batch_take(const std::string &symbol_name, Symbol a, Symbol indices)
Definition: op.h:3007
Symbol mean(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2239
Symbol LinearRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:5149
Symbol softsign(const std::string &symbol_name, Symbol data)
Definition: op.h:1425
Symbol choose_element_0index(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:7273
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Definition: op.h:302
Symbol degrees(const std::string &symbol_name, Symbol data)
Definition: op.h:3688
Symbol one_hot(const std::string &symbol_name, Symbol indices, int depth, double on_value=1, double off_value=0, One_hotDtype dtype=One_hotDtype::kFloat32)
Definition: op.h:3072
Symbol negative(const std::string &symbol_name, Symbol data)
Definition: op.h:1587
Symbol sum(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2202
Symbol GridGenerator(const std::string &symbol_name, Symbol data, GridGeneratorTransformType transform_type, Shape target_shape=Shape(0, 0))
Definition: op.h:6068
Symbol argmax(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1135
Operator interface.
Definition: operator.h:43
Symbol interface.
Definition: symbol.h:72
MakeLossNormalization
Definition: op.h:7014
Symbol log10(const std::string &symbol_name, Symbol data)
Definition: op.h:2031
Symbol dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1315
Symbol log2(const std::string &symbol_name, Symbol data)
Definition: op.h:2052