mxnet
op.h
Go to the documentation of this file.
1 
8 #ifndef MXNET_CPP_OP_H_
9 #define MXNET_CPP_OP_H_
10 
11 #include <string>
12 #include <vector>
13 #include "mxnet-cpp/base.h"
14 #include "mxnet-cpp/shape.h"
15 #include "mxnet-cpp/op_util.h"
16 #include "mxnet-cpp/operator.h"
17 #include "dmlc/optional.h"
18 
19 namespace mxnet {
20 namespace cpp {
21 
62 inline Symbol khatri_rao(const std::string& symbol_name,
63  const std::vector<Symbol>& args) {
64  return Operator("khatri_rao")
65 (args)
66  .CreateSymbol(symbol_name);
67 }
68 
84 inline Symbol Custom(const std::string& symbol_name,
85  const std::vector<Symbol>& data,
86  const std::string& op_type) {
87  return Operator("Custom")
88 (data)
89  .CreateSymbol(symbol_name);
90 }
91 
114 inline Symbol broadcast_power(const std::string& symbol_name,
115  Symbol lhs,
116  Symbol rhs) {
117  return Operator("broadcast_power")
118  .SetInput("lhs", lhs)
119  .SetInput("rhs", rhs)
120  .CreateSymbol(symbol_name);
121 }
122 
147 inline Symbol broadcast_maximum(const std::string& symbol_name,
148  Symbol lhs,
149  Symbol rhs) {
150  return Operator("broadcast_maximum")
151  .SetInput("lhs", lhs)
152  .SetInput("rhs", rhs)
153  .CreateSymbol(symbol_name);
154 }
155 
180 inline Symbol broadcast_minimum(const std::string& symbol_name,
181  Symbol lhs,
182  Symbol rhs) {
183  return Operator("broadcast_minimum")
184  .SetInput("lhs", lhs)
185  .SetInput("rhs", rhs)
186  .CreateSymbol(symbol_name);
187 }
188 
219 inline Symbol broadcast_hypot(const std::string& symbol_name,
220  Symbol lhs,
221  Symbol rhs) {
222  return Operator("broadcast_hypot")
223  .SetInput("lhs", lhs)
224  .SetInput("rhs", rhs)
225  .CreateSymbol(symbol_name);
226 }
227 
302 inline Symbol Reshape(const std::string& symbol_name,
303  Symbol data,
304  Shape shape = Shape(),
305  bool reverse = false,
306  Shape target_shape = Shape(),
307  bool keep_highest = false) {
308  return Operator("Reshape")
309  .SetParam("shape", shape)
310  .SetParam("reverse", reverse)
311  .SetParam("target_shape", target_shape)
312  .SetParam("keep_highest", keep_highest)
313  .SetInput("data", data)
314  .CreateSymbol(symbol_name);
315 }
316 
350 inline Symbol Flatten(const std::string& symbol_name,
351  Symbol data) {
352  return Operator("Flatten")
353  .SetInput("data", data)
354  .CreateSymbol(symbol_name);
355 }
356 
393 inline Symbol transpose(const std::string& symbol_name,
394  Symbol data,
395  Shape axes = Shape()) {
396  return Operator("transpose")
397  .SetParam("axes", axes)
398  .SetInput("data", data)
399  .CreateSymbol(symbol_name);
400 }
401 
417 inline Symbol expand_dims(const std::string& symbol_name,
418  Symbol data,
419  int axis) {
420  return Operator("expand_dims")
421  .SetParam("axis", axis)
422  .SetInput("data", data)
423  .CreateSymbol(symbol_name);
424 }
425 
481 inline Symbol slice(const std::string& symbol_name,
482  Symbol data,
483  Shape begin,
484  Shape end,
485  Shape step = Shape()) {
486  return Operator("slice")
487  .SetParam("begin", begin)
488  .SetParam("end", end)
489  .SetParam("step", step)
490  .SetInput("data", data)
491  .CreateSymbol(symbol_name);
492 }
493 
526 inline Symbol slice_axis(const std::string& symbol_name,
527  Symbol data,
528  int axis,
529  int begin,
530  dmlc::optional<int> end) {
531  return Operator("slice_axis")
532  .SetParam("axis", axis)
533  .SetParam("begin", begin)
534  .SetParam("end", end)
535  .SetInput("data", data)
536  .CreateSymbol(symbol_name);
537 }
538 
600 inline Symbol slice_like(const std::string& symbol_name,
601  Symbol data,
602  Symbol shape_like,
603  Shape axes = Shape()) {
604  return Operator("slice_like")
605  .SetParam("axes", axes)
606  .SetInput("data", data)
607  .SetInput("shape_like", shape_like)
608  .CreateSymbol(symbol_name);
609 }
610 
645 inline Symbol clip(const std::string& symbol_name,
646  Symbol data,
647  mx_float a_min,
648  mx_float a_max) {
649  return Operator("clip")
650  .SetParam("a_min", a_min)
651  .SetParam("a_max", a_max)
652  .SetInput("data", data)
653  .CreateSymbol(symbol_name);
654 }
655 
690 inline Symbol repeat(const std::string& symbol_name,
691  Symbol data,
692  int repeats,
693  dmlc::optional<int> axis = dmlc::optional<int>()) {
694  return Operator("repeat")
695  .SetParam("repeats", repeats)
696  .SetParam("axis", axis)
697  .SetInput("data", data)
698  .CreateSymbol(symbol_name);
699 }
700 
746 inline Symbol tile(const std::string& symbol_name,
747  Symbol data,
748  Shape reps) {
749  return Operator("tile")
750  .SetParam("reps", reps)
751  .SetInput("data", data)
752  .CreateSymbol(symbol_name);
753 }
754 
778 inline Symbol reverse(const std::string& symbol_name,
779  Symbol data,
780  Shape axis) {
781  return Operator("reverse")
782  .SetParam("axis", axis)
783  .SetInput("data", data)
784  .CreateSymbol(symbol_name);
785 }
786 
810 inline Symbol stack(const std::string& symbol_name,
811  const std::vector<Symbol>& data,
812  int num_args,
813  int axis = 0) {
814  return Operator("stack")
815  .SetParam("num_args", num_args)
816  .SetParam("axis", axis)
817 (data)
818  .CreateSymbol(symbol_name);
819 }
820 
843 inline Symbol squeeze(const std::string& symbol_name,
844  const std::vector<Symbol>& data,
845  dmlc::optional<Shape> axis = dmlc::optional<Shape>()) {
846  return Operator("squeeze")
847  .SetParam("axis", axis)
848 (data)
849  .CreateSymbol(symbol_name);
850 }
851 
893 inline Symbol depth_to_space(const std::string& symbol_name,
894  Symbol data,
895  int block_size) {
896  return Operator("depth_to_space")
897  .SetParam("block_size", block_size)
898  .SetInput("data", data)
899  .CreateSymbol(symbol_name);
900 }
901 
945 inline Symbol space_to_depth(const std::string& symbol_name,
946  Symbol data,
947  int block_size) {
948  return Operator("space_to_depth")
949  .SetParam("block_size", block_size)
950  .SetInput("data", data)
951  .CreateSymbol(symbol_name);
952 }
953 
977 inline Symbol zeros_like(const std::string& symbol_name,
978  Symbol data) {
979  return Operator("zeros_like")
980  .SetInput("data", data)
981  .CreateSymbol(symbol_name);
982 }
983 
1001 inline Symbol ones_like(const std::string& symbol_name,
1002  Symbol data) {
1003  return Operator("ones_like")
1004  .SetInput("data", data)
1005  .CreateSymbol(symbol_name);
1006 }
1007 
1030 inline Symbol add_n(const std::string& symbol_name,
1031  const std::vector<Symbol>& args) {
1032  return Operator("add_n")
1033 (args)
1034  .CreateSymbol(symbol_name);
1035 }
1036 
1068 inline Symbol argmax(const std::string& symbol_name,
1069  Symbol data,
1070  dmlc::optional<int> axis = dmlc::optional<int>(),
1071  bool keepdims = false) {
1072  return Operator("argmax")
1073  .SetParam("axis", axis)
1074  .SetParam("keepdims", keepdims)
1075  .SetInput("data", data)
1076  .CreateSymbol(symbol_name);
1077 }
1078 
1110 inline Symbol argmin(const std::string& symbol_name,
1111  Symbol data,
1112  dmlc::optional<int> axis = dmlc::optional<int>(),
1113  bool keepdims = false) {
1114  return Operator("argmin")
1115  .SetParam("axis", axis)
1116  .SetParam("keepdims", keepdims)
1117  .SetInput("data", data)
1118  .CreateSymbol(symbol_name);
1119 }
1120 
1143 inline Symbol argmax_channel(const std::string& symbol_name,
1144  Symbol data) {
1145  return Operator("argmax_channel")
1146  .SetInput("data", data)
1147  .CreateSymbol(symbol_name);
1148 }
1149 
/*!
 * \brief Values accepted by the `mode` argument of pick().
 *        The numeric value is used as an index into pick()'s string table.
 */
enum class PickMode {
  kClip = 0,  // sent as "clip"
  kWrap = 1   // sent as "wrap"
};
1158 
1215 inline Symbol pick(const std::string& symbol_name,
1216  Symbol data,
1217  Symbol index,
1218  dmlc::optional<int> axis = dmlc::optional<int>(-1),
1219  bool keepdims = false,
1220  PickMode mode = PickMode::kClip) {
1221  static const char *PickModeValues[] = {
1222  "clip",
1223  "wrap"
1224  };
1225  return Operator("pick")
1226  .SetParam("axis", axis)
1227  .SetParam("keepdims", keepdims)
1228  .SetParam("mode", PickModeValues[int(mode)])
1229  .SetInput("data", data)
1230  .SetInput("index", index)
1231  .CreateSymbol(symbol_name);
1232 }
1233 
/*!
 * \brief Values accepted by the `forward_stype` argument of dot().
 *        The numeric value is used as an index into dot()'s string table.
 */
enum class DotForwardStype {
  kNone = 0,       // sent as "None"
  kCsr = 1,        // sent as "csr"
  kDefault = 2,    // sent as "default"
  kRow_sparse = 3  // sent as "row_sparse"
};
1244 
1303 inline Symbol dot(const std::string& symbol_name,
1304  Symbol lhs,
1305  Symbol rhs,
1306  bool transpose_a = false,
1307  bool transpose_b = false,
1308  DotForwardStype forward_stype = DotForwardStype::kNone) {
1309  static const char *DotForwardStypeValues[] = {
1310  "None",
1311  "csr",
1312  "default",
1313  "row_sparse"
1314  };
1315  return Operator("dot")
1316  .SetParam("transpose_a", transpose_a)
1317  .SetParam("transpose_b", transpose_b)
1318  .SetParam("forward_stype", DotForwardStypeValues[int(forward_stype)])
1319  .SetInput("lhs", lhs)
1320  .SetInput("rhs", rhs)
1321  .CreateSymbol(symbol_name);
1322 }
1323 
1329  kNone = 0,
1330  kCsr = 1,
1331  kDefault = 2,
1332  kRow_sparse = 3
1333 };
1334 
1360 inline Symbol batch_dot(const std::string& symbol_name,
1361  Symbol lhs,
1362  Symbol rhs,
1363  bool transpose_a = false,
1364  bool transpose_b = false,
1366  static const char *Batch_dotForwardStypeValues[] = {
1367  "None",
1368  "csr",
1369  "default",
1370  "row_sparse"
1371  };
1372  return Operator("batch_dot")
1373  .SetParam("transpose_a", transpose_a)
1374  .SetParam("transpose_b", transpose_b)
1375  .SetParam("forward_stype", Batch_dotForwardStypeValues[int(forward_stype)])
1376  .SetInput("lhs", lhs)
1377  .SetInput("rhs", rhs)
1378  .CreateSymbol(symbol_name);
1379 }
1380 
1413 inline Symbol broadcast_add(const std::string& symbol_name,
1414  Symbol lhs,
1415  Symbol rhs) {
1416  return Operator("broadcast_add")
1417  .SetInput("lhs", lhs)
1418  .SetInput("rhs", rhs)
1419  .CreateSymbol(symbol_name);
1420 }
1421 
1454 inline Symbol broadcast_sub(const std::string& symbol_name,
1455  Symbol lhs,
1456  Symbol rhs) {
1457  return Operator("broadcast_sub")
1458  .SetInput("lhs", lhs)
1459  .SetInput("rhs", rhs)
1460  .CreateSymbol(symbol_name);
1461 }
1462 
1489 inline Symbol broadcast_mul(const std::string& symbol_name,
1490  Symbol lhs,
1491  Symbol rhs) {
1492  return Operator("broadcast_mul")
1493  .SetInput("lhs", lhs)
1494  .SetInput("rhs", rhs)
1495  .CreateSymbol(symbol_name);
1496 }
1497 
1524 inline Symbol broadcast_div(const std::string& symbol_name,
1525  Symbol lhs,
1526  Symbol rhs) {
1527  return Operator("broadcast_div")
1528  .SetInput("lhs", lhs)
1529  .SetInput("rhs", rhs)
1530  .CreateSymbol(symbol_name);
1531 }
1532 
1555 inline Symbol broadcast_mod(const std::string& symbol_name,
1556  Symbol lhs,
1557  Symbol rhs) {
1558  return Operator("broadcast_mod")
1559  .SetInput("lhs", lhs)
1560  .SetInput("rhs", rhs)
1561  .CreateSymbol(symbol_name);
1562 }
1563 
1583 inline Symbol relu(const std::string& symbol_name,
1584  Symbol data) {
1585  return Operator("relu")
1586  .SetInput("data", data)
1587  .CreateSymbol(symbol_name);
1588 }
1589 
1605 inline Symbol sigmoid(const std::string& symbol_name,
1606  Symbol data) {
1607  return Operator("sigmoid")
1608  .SetInput("data", data)
1609  .CreateSymbol(symbol_name);
1610 }
1611 
1627 inline Symbol hard_sigmoid(const std::string& symbol_name,
1628  Symbol data,
1629  mx_float alpha = 0.2,
1630  mx_float beta = 0.5) {
1631  return Operator("hard_sigmoid")
1632  .SetParam("alpha", alpha)
1633  .SetParam("beta", beta)
1634  .SetInput("data", data)
1635  .CreateSymbol(symbol_name);
1636 }
1637 
1653 inline Symbol softsign(const std::string& symbol_name,
1654  Symbol data) {
1655  return Operator("softsign")
1656  .SetInput("data", data)
1657  .CreateSymbol(symbol_name);
1658 }
1659 
1693 inline Symbol BlockGrad(const std::string& symbol_name,
1694  Symbol data) {
1695  return Operator("BlockGrad")
1696  .SetInput("data", data)
1697  .CreateSymbol(symbol_name);
1698 }
1699 
1729 inline Symbol make_loss(const std::string& symbol_name,
1730  Symbol data) {
1731  return Operator("make_loss")
1732  .SetInput("data", data)
1733  .CreateSymbol(symbol_name);
1734 }
1735 
1770 inline Symbol reshape_like(const std::string& symbol_name,
1771  Symbol lhs,
1772  Symbol rhs) {
1773  return Operator("reshape_like")
1774  .SetInput("lhs", lhs)
1775  .SetInput("rhs", rhs)
1776  .CreateSymbol(symbol_name);
1777 }
1778 
1797 inline Symbol shape_array(const std::string& symbol_name,
1798  Symbol data,
1799  dmlc::optional<int> lhs_begin = dmlc::optional<int>(),
1800  dmlc::optional<int> lhs_end = dmlc::optional<int>(),
1801  dmlc::optional<int> rhs_begin = dmlc::optional<int>(),
1802  dmlc::optional<int> rhs_end = dmlc::optional<int>()) {
1803  return Operator("shape_array")
1804  .SetParam("lhs_begin", lhs_begin)
1805  .SetParam("lhs_end", lhs_end)
1806  .SetParam("rhs_begin", rhs_begin)
1807  .SetParam("rhs_end", rhs_end)
1808  .SetInput("data", data)
1809  .CreateSymbol(symbol_name);
1810 }
1811 
1826 inline Symbol size_array(const std::string& symbol_name,
1827  Symbol data) {
1828  return Operator("size_array")
1829  .SetInput("data", data)
1830  .CreateSymbol(symbol_name);
1831 }
1832 
/*!
 * \brief Values accepted by the `dtype` argument of Cast().
 *        The numeric value is used as an index into Cast()'s string table.
 */
enum class CastDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32 = 1,  // sent as "float32"
  kFloat64 = 2,  // sent as "float64"
  kInt32 = 3,    // sent as "int32"
  kInt64 = 4,    // sent as "int64"
  kInt8 = 5,     // sent as "int8"
  kUint8 = 6     // sent as "uint8"
};
1844 
1864 inline Symbol Cast(const std::string& symbol_name,
1865  Symbol data,
1866  CastDtype dtype) {
1867  static const char *CastDtypeValues[] = {
1868  "float16",
1869  "float32",
1870  "float64",
1871  "int32",
1872  "int64",
1873  "int8",
1874  "uint8"
1875  };
1876  return Operator("Cast")
1877  .SetParam("dtype", CastDtypeValues[int(dtype)])
1878  .SetInput("data", data)
1879  .CreateSymbol(symbol_name);
1880 }
1881 
1896 inline Symbol negative(const std::string& symbol_name,
1897  Symbol data) {
1898  return Operator("negative")
1899  .SetInput("data", data)
1900  .CreateSymbol(symbol_name);
1901 }
1902 
1919 inline Symbol reciprocal(const std::string& symbol_name,
1920  Symbol data) {
1921  return Operator("reciprocal")
1922  .SetInput("data", data)
1923  .CreateSymbol(symbol_name);
1924 }
1925 
1946 inline Symbol abs(const std::string& symbol_name,
1947  Symbol data) {
1948  return Operator("abs")
1949  .SetInput("data", data)
1950  .CreateSymbol(symbol_name);
1951 }
1952 
1973 inline Symbol sign(const std::string& symbol_name,
1974  Symbol data) {
1975  return Operator("sign")
1976  .SetInput("data", data)
1977  .CreateSymbol(symbol_name);
1978 }
1979 
2000 inline Symbol round(const std::string& symbol_name,
2001  Symbol data) {
2002  return Operator("round")
2003  .SetInput("data", data)
2004  .CreateSymbol(symbol_name);
2005 }
2006 
2031 inline Symbol rint(const std::string& symbol_name,
2032  Symbol data) {
2033  return Operator("rint")
2034  .SetInput("data", data)
2035  .CreateSymbol(symbol_name);
2036 }
2037 
2060 inline Symbol ceil(const std::string& symbol_name,
2061  Symbol data) {
2062  return Operator("ceil")
2063  .SetInput("data", data)
2064  .CreateSymbol(symbol_name);
2065 }
2066 
2089 inline Symbol floor(const std::string& symbol_name,
2090  Symbol data) {
2091  return Operator("floor")
2092  .SetInput("data", data)
2093  .CreateSymbol(symbol_name);
2094 }
2095 
2119 inline Symbol trunc(const std::string& symbol_name,
2120  Symbol data) {
2121  return Operator("trunc")
2122  .SetInput("data", data)
2123  .CreateSymbol(symbol_name);
2124 }
2125 
2147 inline Symbol fix(const std::string& symbol_name,
2148  Symbol data) {
2149  return Operator("fix")
2150  .SetInput("data", data)
2151  .CreateSymbol(symbol_name);
2152 }
2153 
2177 inline Symbol square(const std::string& symbol_name,
2178  Symbol data) {
2179  return Operator("square")
2180  .SetInput("data", data)
2181  .CreateSymbol(symbol_name);
2182 }
2183 
2207 inline Symbol sqrt(const std::string& symbol_name,
2208  Symbol data) {
2209  return Operator("sqrt")
2210  .SetInput("data", data)
2211  .CreateSymbol(symbol_name);
2212 }
2213 
2233 inline Symbol rsqrt(const std::string& symbol_name,
2234  Symbol data) {
2235  return Operator("rsqrt")
2236  .SetInput("data", data)
2237  .CreateSymbol(symbol_name);
2238 }
2239 
2263 inline Symbol cbrt(const std::string& symbol_name,
2264  Symbol data) {
2265  return Operator("cbrt")
2266  .SetInput("data", data)
2267  .CreateSymbol(symbol_name);
2268 }
2269 
2284 inline Symbol erf(const std::string& symbol_name,
2285  Symbol data) {
2286  return Operator("erf")
2287  .SetInput("data", data)
2288  .CreateSymbol(symbol_name);
2289 }
2290 
2308 inline Symbol rcbrt(const std::string& symbol_name,
2309  Symbol data) {
2310  return Operator("rcbrt")
2311  .SetInput("data", data)
2312  .CreateSymbol(symbol_name);
2313 }
2314 
2334 inline Symbol exp(const std::string& symbol_name,
2335  Symbol data) {
2336  return Operator("exp")
2337  .SetInput("data", data)
2338  .CreateSymbol(symbol_name);
2339 }
2340 
2355 inline Symbol log(const std::string& symbol_name,
2356  Symbol data) {
2357  return Operator("log")
2358  .SetInput("data", data)
2359  .CreateSymbol(symbol_name);
2360 }
2361 
2376 inline Symbol log10(const std::string& symbol_name,
2377  Symbol data) {
2378  return Operator("log10")
2379  .SetInput("data", data)
2380  .CreateSymbol(symbol_name);
2381 }
2382 
2397 inline Symbol log2(const std::string& symbol_name,
2398  Symbol data) {
2399  return Operator("log2")
2400  .SetInput("data", data)
2401  .CreateSymbol(symbol_name);
2402 }
2403 
2423 inline Symbol log1p(const std::string& symbol_name,
2424  Symbol data) {
2425  return Operator("log1p")
2426  .SetInput("data", data)
2427  .CreateSymbol(symbol_name);
2428 }
2429 
2448 inline Symbol expm1(const std::string& symbol_name,
2449  Symbol data) {
2450  return Operator("expm1")
2451  .SetInput("data", data)
2452  .CreateSymbol(symbol_name);
2453 }
2454 
2466 inline Symbol gamma(const std::string& symbol_name,
2467  Symbol data) {
2468  return Operator("gamma")
2469  .SetInput("data", data)
2470  .CreateSymbol(symbol_name);
2471 }
2472 
2484 inline Symbol gammaln(const std::string& symbol_name,
2485  Symbol data) {
2486  return Operator("gammaln")
2487  .SetInput("data", data)
2488  .CreateSymbol(symbol_name);
2489 }
2490 
2502 inline Symbol logical_not(const std::string& symbol_name,
2503  Symbol data) {
2504  return Operator("logical_not")
2505  .SetInput("data", data)
2506  .CreateSymbol(symbol_name);
2507 }
2508 
2567 inline Symbol sum(const std::string& symbol_name,
2568  Symbol data,
2569  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2570  bool keepdims = false,
2571  bool exclude = false) {
2572  return Operator("sum")
2573  .SetParam("axis", axis)
2574  .SetParam("keepdims", keepdims)
2575  .SetParam("exclude", exclude)
2576  .SetInput("data", data)
2577  .CreateSymbol(symbol_name);
2578 }
2579 
2604 inline Symbol mean(const std::string& symbol_name,
2605  Symbol data,
2606  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2607  bool keepdims = false,
2608  bool exclude = false) {
2609  return Operator("mean")
2610  .SetParam("axis", axis)
2611  .SetParam("keepdims", keepdims)
2612  .SetParam("exclude", exclude)
2613  .SetInput("data", data)
2614  .CreateSymbol(symbol_name);
2615 }
2616 
2641 inline Symbol prod(const std::string& symbol_name,
2642  Symbol data,
2643  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2644  bool keepdims = false,
2645  bool exclude = false) {
2646  return Operator("prod")
2647  .SetParam("axis", axis)
2648  .SetParam("keepdims", keepdims)
2649  .SetParam("exclude", exclude)
2650  .SetInput("data", data)
2651  .CreateSymbol(symbol_name);
2652 }
2653 
2680 inline Symbol nansum(const std::string& symbol_name,
2681  Symbol data,
2682  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2683  bool keepdims = false,
2684  bool exclude = false) {
2685  return Operator("nansum")
2686  .SetParam("axis", axis)
2687  .SetParam("keepdims", keepdims)
2688  .SetParam("exclude", exclude)
2689  .SetInput("data", data)
2690  .CreateSymbol(symbol_name);
2691 }
2692 
2719 inline Symbol nanprod(const std::string& symbol_name,
2720  Symbol data,
2721  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2722  bool keepdims = false,
2723  bool exclude = false) {
2724  return Operator("nanprod")
2725  .SetParam("axis", axis)
2726  .SetParam("keepdims", keepdims)
2727  .SetParam("exclude", exclude)
2728  .SetInput("data", data)
2729  .CreateSymbol(symbol_name);
2730 }
2731 
2756 inline Symbol max(const std::string& symbol_name,
2757  Symbol data,
2758  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2759  bool keepdims = false,
2760  bool exclude = false) {
2761  return Operator("max")
2762  .SetParam("axis", axis)
2763  .SetParam("keepdims", keepdims)
2764  .SetParam("exclude", exclude)
2765  .SetInput("data", data)
2766  .CreateSymbol(symbol_name);
2767 }
2768 
2793 inline Symbol min(const std::string& symbol_name,
2794  Symbol data,
2795  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2796  bool keepdims = false,
2797  bool exclude = false) {
2798  return Operator("min")
2799  .SetParam("axis", axis)
2800  .SetParam("keepdims", keepdims)
2801  .SetParam("exclude", exclude)
2802  .SetInput("data", data)
2803  .CreateSymbol(symbol_name);
2804 }
2805 
2835 inline Symbol broadcast_axis(const std::string& symbol_name,
2836  Symbol data,
2837  Shape axis = Shape(),
2838  Shape size = Shape()) {
2839  return Operator("broadcast_axis")
2840  .SetParam("axis", axis)
2841  .SetParam("size", size)
2842  .SetInput("data", data)
2843  .CreateSymbol(symbol_name);
2844 }
2845 
2874 inline Symbol broadcast_to(const std::string& symbol_name,
2875  Symbol data,
2876  Shape shape = Shape()) {
2877  return Operator("broadcast_to")
2878  .SetParam("shape", shape)
2879  .SetInput("data", data)
2880  .CreateSymbol(symbol_name);
2881 }
2882 
2911 inline Symbol broadcast_like(const std::string& symbol_name,
2912  Symbol lhs,
2913  Symbol rhs,
2914  dmlc::optional<Shape> lhs_axes = dmlc::optional<Shape>(),
2915  dmlc::optional<Shape> rhs_axes = dmlc::optional<Shape>()) {
2916  return Operator("broadcast_like")
2917  .SetParam("lhs_axes", lhs_axes)
2918  .SetParam("rhs_axes", rhs_axes)
2919  .SetInput("lhs", lhs)
2920  .SetInput("rhs", rhs)
2921  .CreateSymbol(symbol_name);
2922 }
2923 
2967 inline Symbol norm(const std::string& symbol_name,
2968  Symbol data,
2969  int ord = 2,
2970  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2971  bool keepdims = false) {
2972  return Operator("norm")
2973  .SetParam("ord", ord)
2974  .SetParam("axis", axis)
2975  .SetParam("keepdims", keepdims)
2976  .SetInput("data", data)
2977  .CreateSymbol(symbol_name);
2978 }
2979 
/*!
 * \brief Values accepted by the `ret_typ` argument of topk().
 *        The numeric value is used as an index into topk()'s string table.
 */
enum class TopkRetTyp {
  kBoth = 0,     // sent as "both"
  kIndices = 1,  // sent as "indices"
  kMask = 2,     // sent as "mask"
  kValue = 3     // sent as "value"
};
2991 
/*!
 * \brief Values accepted by the `dtype` argument of topk().
 *        The numeric value is used as an index into topk()'s string table.
 */
enum class TopkDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32 = 1,  // sent as "float32"
  kFloat64 = 2,  // sent as "float64"
  kInt32 = 3,    // sent as "int32"
  kUint8 = 4     // sent as "uint8"
};
3001 
3045 inline Symbol topk(const std::string& symbol_name,
3046  Symbol data,
3047  dmlc::optional<int> axis = dmlc::optional<int>(-1),
3048  int k = 1,
3049  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
3050  bool is_ascend = false,
3051  TopkDtype dtype = TopkDtype::kFloat32) {
3052  static const char *TopkRetTypValues[] = {
3053  "both",
3054  "indices",
3055  "mask",
3056  "value"
3057  };
3058  static const char *TopkDtypeValues[] = {
3059  "float16",
3060  "float32",
3061  "float64",
3062  "int32",
3063  "uint8"
3064  };
3065  return Operator("topk")
3066  .SetParam("axis", axis)
3067  .SetParam("k", k)
3068  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
3069  .SetParam("is_ascend", is_ascend)
3070  .SetParam("dtype", TopkDtypeValues[int(dtype)])
3071  .SetInput("data", data)
3072  .CreateSymbol(symbol_name);
3073 }
3074 
3107 inline Symbol sort(const std::string& symbol_name,
3108  Symbol data,
3109  dmlc::optional<int> axis = dmlc::optional<int>(-1),
3110  bool is_ascend = true) {
3111  return Operator("sort")
3112  .SetParam("axis", axis)
3113  .SetParam("is_ascend", is_ascend)
3114  .SetInput("data", data)
3115  .CreateSymbol(symbol_name);
3116 }
3117 
/*!
 * \brief Values accepted by the `dtype` argument of argsort().
 *        The numeric value is used as an index into argsort()'s string table.
 */
enum class ArgsortDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32 = 1,  // sent as "float32"
  kFloat64 = 2,  // sent as "float64"
  kInt32 = 3,    // sent as "int32"
  kUint8 = 4     // sent as "uint8"
};
3128 
3161 inline Symbol argsort(const std::string& symbol_name,
3162  Symbol data,
3163  dmlc::optional<int> axis = dmlc::optional<int>(-1),
3164  bool is_ascend = true,
3166  static const char *ArgsortDtypeValues[] = {
3167  "float16",
3168  "float32",
3169  "float64",
3170  "int32",
3171  "uint8"
3172  };
3173  return Operator("argsort")
3174  .SetParam("axis", axis)
3175  .SetParam("is_ascend", is_ascend)
3176  .SetParam("dtype", ArgsortDtypeValues[int(dtype)])
3177  .SetInput("data", data)
3178  .CreateSymbol(symbol_name);
3179 }
3180 
3200 inline Symbol elemwise_add(const std::string& symbol_name,
3201  Symbol lhs,
3202  Symbol rhs) {
3203  return Operator("elemwise_add")
3204  .SetInput("lhs", lhs)
3205  .SetInput("rhs", rhs)
3206  .CreateSymbol(symbol_name);
3207 }
3208 
3228 inline Symbol elemwise_sub(const std::string& symbol_name,
3229  Symbol lhs,
3230  Symbol rhs) {
3231  return Operator("elemwise_sub")
3232  .SetInput("lhs", lhs)
3233  .SetInput("rhs", rhs)
3234  .CreateSymbol(symbol_name);
3235 }
3236 
3255 inline Symbol elemwise_mul(const std::string& symbol_name,
3256  Symbol lhs,
3257  Symbol rhs) {
3258  return Operator("elemwise_mul")
3259  .SetInput("lhs", lhs)
3260  .SetInput("rhs", rhs)
3261  .CreateSymbol(symbol_name);
3262 }
3263 
3275 inline Symbol elemwise_div(const std::string& symbol_name,
3276  Symbol lhs,
3277  Symbol rhs) {
3278  return Operator("elemwise_div")
3279  .SetInput("lhs", lhs)
3280  .SetInput("rhs", rhs)
3281  .CreateSymbol(symbol_name);
3282 }
3283 
/*!
 * \brief Values accepted by the `dtype` setting of Embedding().
 *        The numeric value is used as an index into Embedding()'s string table.
 */
enum class EmbeddingDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32 = 1,  // sent as "float32"
  kFloat64 = 2,  // sent as "float64"
  kInt32 = 3,    // sent as "int32"
  kInt64 = 4,    // sent as "int64"
  kInt8 = 5,     // sent as "int8"
  kUint8 = 6     // sent as "uint8"
};
3295 
3359 inline Symbol Embedding(const std::string& symbol_name,
3360  Symbol data,
3361  Symbol weight,
3362  int input_dim,
3363  int output_dim,
3365  bool sparse_grad = false) {
3366  static const char *EmbeddingDtypeValues[] = {
3367  "float16",
3368  "float32",
3369  "float64",
3370  "int32",
3371  "int64",
3372  "int8",
3373  "uint8"
3374  };
3375  return Operator("Embedding")
3376  .SetParam("input_dim", input_dim)
3377  .SetParam("output_dim", output_dim)
3378  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
3379  .SetParam("sparse_grad", sparse_grad)
3380  .SetInput("data", data)
3381  .SetInput("weight", weight)
3382  .CreateSymbol(symbol_name);
3383 }
3384 
/*!
 * \brief Values accepted by the `mode` argument of take().
 *        The numeric value is used as an index into take()'s string table.
 */
enum class TakeMode {
  kClip = 0,   // sent as "clip"
  kRaise = 1,  // sent as "raise"
  kWrap = 2    // sent as "wrap"
};
3394 
3454 inline Symbol take(const std::string& symbol_name,
3455  Symbol a,
3456  Symbol indices,
3457  int axis = 0,
3458  TakeMode mode = TakeMode::kClip) {
3459  static const char *TakeModeValues[] = {
3460  "clip",
3461  "raise",
3462  "wrap"
3463  };
3464  return Operator("take")
3465  .SetParam("axis", axis)
3466  .SetParam("mode", TakeModeValues[int(mode)])
3467  .SetInput("a", a)
3468  .SetInput("indices", indices)
3469  .CreateSymbol(symbol_name);
3470 }
3471 
3500 inline Symbol batch_take(const std::string& symbol_name,
3501  Symbol a,
3502  Symbol indices) {
3503  return Operator("batch_take")
3504  .SetInput("a", a)
3505  .SetInput("indices", indices)
3506  .CreateSymbol(symbol_name);
3507 }
3508 
/*!
 * \brief Values accepted by the `dtype` setting of one_hot().
 *        The numeric value is used as an index into one_hot()'s string table.
 */
enum class One_hotDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32 = 1,  // sent as "float32"
  kFloat64 = 2,  // sent as "float64"
  kInt32 = 3,    // sent as "int32"
  kInt64 = 4,    // sent as "int64"
  kInt8 = 5,     // sent as "int8"
  kUint8 = 6     // sent as "uint8"
};
3520 
3565 inline Symbol one_hot(const std::string& symbol_name,
3566  Symbol indices,
3567  int depth,
3568  double on_value = 1,
3569  double off_value = 0,
3571  static const char *One_hotDtypeValues[] = {
3572  "float16",
3573  "float32",
3574  "float64",
3575  "int32",
3576  "int64",
3577  "int8",
3578  "uint8"
3579  };
3580  return Operator("one_hot")
3581  .SetParam("depth", depth)
3582  .SetParam("on_value", on_value)
3583  .SetParam("off_value", off_value)
3584  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
3585  .SetInput("indices", indices)
3586  .CreateSymbol(symbol_name);
3587 }
3588 
3620 inline Symbol gather_nd(const std::string& symbol_name,
3621  Symbol data,
3622  Symbol indices) {
3623  return Operator("gather_nd")
3624  .SetInput("data", data)
3625  .SetInput("indices", indices)
3626  .CreateSymbol(symbol_name);
3627 }
3628 
3680 inline Symbol scatter_nd(const std::string& symbol_name,
3681  Symbol data,
3682  Symbol indices,
3683  Shape shape) {
3684  return Operator("scatter_nd")
3685  .SetParam("shape", shape)
3686  .SetInput("data", data)
3687  .SetInput("indices", indices)
3688  .CreateSymbol(symbol_name);
3689 }
3690 
3713 inline Symbol broadcast_equal(const std::string& symbol_name,
3714  Symbol lhs,
3715  Symbol rhs) {
3716  return Operator("broadcast_equal")
3717  .SetInput("lhs", lhs)
3718  .SetInput("rhs", rhs)
3719  .CreateSymbol(symbol_name);
3720 }
3721 
3744 inline Symbol broadcast_not_equal(const std::string& symbol_name,
3745  Symbol lhs,
3746  Symbol rhs) {
3747  return Operator("broadcast_not_equal")
3748  .SetInput("lhs", lhs)
3749  .SetInput("rhs", rhs)
3750  .CreateSymbol(symbol_name);
3751 }
3752 
3775 inline Symbol broadcast_greater(const std::string& symbol_name,
3776  Symbol lhs,
3777  Symbol rhs) {
3778  return Operator("broadcast_greater")
3779  .SetInput("lhs", lhs)
3780  .SetInput("rhs", rhs)
3781  .CreateSymbol(symbol_name);
3782 }
3783 
3806 inline Symbol broadcast_greater_equal(const std::string& symbol_name,
3807  Symbol lhs,
3808  Symbol rhs) {
3809  return Operator("broadcast_greater_equal")
3810  .SetInput("lhs", lhs)
3811  .SetInput("rhs", rhs)
3812  .CreateSymbol(symbol_name);
3813 }
3814 
3837 inline Symbol broadcast_lesser(const std::string& symbol_name,
3838  Symbol lhs,
3839  Symbol rhs) {
3840  return Operator("broadcast_lesser")
3841  .SetInput("lhs", lhs)
3842  .SetInput("rhs", rhs)
3843  .CreateSymbol(symbol_name);
3844 }
3845 
3868 inline Symbol broadcast_lesser_equal(const std::string& symbol_name,
3869  Symbol lhs,
3870  Symbol rhs) {
3871  return Operator("broadcast_lesser_equal")
3872  .SetInput("lhs", lhs)
3873  .SetInput("rhs", rhs)
3874  .CreateSymbol(symbol_name);
3875 }
3876 
3899 inline Symbol broadcast_logical_and(const std::string& symbol_name,
3900  Symbol lhs,
3901  Symbol rhs) {
3902  return Operator("broadcast_logical_and")
3903  .SetInput("lhs", lhs)
3904  .SetInput("rhs", rhs)
3905  .CreateSymbol(symbol_name);
3906 }
3907 
3930 inline Symbol broadcast_logical_or(const std::string& symbol_name,
3931  Symbol lhs,
3932  Symbol rhs) {
3933  return Operator("broadcast_logical_or")
3934  .SetInput("lhs", lhs)
3935  .SetInput("rhs", rhs)
3936  .CreateSymbol(symbol_name);
3937 }
3938 
3961 inline Symbol broadcast_logical_xor(const std::string& symbol_name,
3962  Symbol lhs,
3963  Symbol rhs) {
3964  return Operator("broadcast_logical_xor")
3965  .SetInput("lhs", lhs)
3966  .SetInput("rhs", rhs)
3967  .CreateSymbol(symbol_name);
3968 }
3969 
4034 inline Symbol diag(const std::string& symbol_name,
4035  Symbol data,
4036  int k = 0,
4037  int axis1 = 0,
4038  int axis2 = 1) {
4039  return Operator("diag")
4040  .SetParam("k", k)
4041  .SetParam("axis1", axis1)
4042  .SetParam("axis2", axis2)
4043  .SetInput("data", data)
4044  .CreateSymbol(symbol_name);
4045 }
4046 
4082 inline Symbol where(const std::string& symbol_name,
4083  Symbol condition,
4084  Symbol x,
4085  Symbol y) {
4086  return Operator("where")
4087  .SetInput("condition", condition)
4088  .SetInput("x", x)
4089  .SetInput("y", y)
4090  .CreateSymbol(symbol_name);
4091 }
4092 
4119 inline Symbol smooth_l1(const std::string& symbol_name,
4120  Symbol data,
4121  mx_float scalar) {
4122  return Operator("smooth_l1")
4123  .SetParam("scalar", scalar)
4124  .SetInput("data", data)
4125  .CreateSymbol(symbol_name);
4126 }
4127 
/*!
 * \brief Values accepted by the `stype` argument of cast_storage().
 *        The numeric value is used as an index into cast_storage()'s string table.
 */
enum class Cast_storageStype {
  kCsr = 0,        // sent as "csr"
  kDefault = 1,    // sent as "default"
  kRow_sparse = 2  // sent as "row_sparse"
};
4135 
4181 inline Symbol cast_storage(const std::string& symbol_name,
4182  Symbol data,
4183  Cast_storageStype stype) {
4184  static const char *Cast_storageStypeValues[] = {
4185  "csr",
4186  "default",
4187  "row_sparse"
4188  };
4189  return Operator("cast_storage")
4190  .SetParam("stype", Cast_storageStypeValues[int(stype)])
4191  .SetInput("data", data)
4192  .CreateSymbol(symbol_name);
4193 }
4194 
4216 inline Symbol sin(const std::string& symbol_name,
4217  Symbol data) {
4218  return Operator("sin")
4219  .SetInput("data", data)
4220  .CreateSymbol(symbol_name);
4221 }
4222 
4240 inline Symbol cos(const std::string& symbol_name,
4241  Symbol data) {
4242  return Operator("cos")
4243  .SetInput("data", data)
4244  .CreateSymbol(symbol_name);
4245 }
4246 
4268 inline Symbol tan(const std::string& symbol_name,
4269  Symbol data) {
4270  return Operator("tan")
4271  .SetInput("data", data)
4272  .CreateSymbol(symbol_name);
4273 }
4274 
4297 inline Symbol arcsin(const std::string& symbol_name,
4298  Symbol data) {
4299  return Operator("arcsin")
4300  .SetInput("data", data)
4301  .CreateSymbol(symbol_name);
4302 }
4303 
4322 inline Symbol arccos(const std::string& symbol_name,
4323  Symbol data) {
4324  return Operator("arccos")
4325  .SetInput("data", data)
4326  .CreateSymbol(symbol_name);
4327 }
4328 
4350 inline Symbol arctan(const std::string& symbol_name,
4351  Symbol data) {
4352  return Operator("arctan")
4353  .SetInput("data", data)
4354  .CreateSymbol(symbol_name);
4355 }
4356 
4376 inline Symbol degrees(const std::string& symbol_name,
4377  Symbol data) {
4378  return Operator("degrees")
4379  .SetInput("data", data)
4380  .CreateSymbol(symbol_name);
4381 }
4382 
4402 inline Symbol radians(const std::string& symbol_name,
4403  Symbol data) {
4404  return Operator("radians")
4405  .SetInput("data", data)
4406  .CreateSymbol(symbol_name);
4407 }
4408 
4428 inline Symbol sinh(const std::string& symbol_name,
4429  Symbol data) {
4430  return Operator("sinh")
4431  .SetInput("data", data)
4432  .CreateSymbol(symbol_name);
4433 }
4434 
4450 inline Symbol cosh(const std::string& symbol_name,
4451  Symbol data) {
4452  return Operator("cosh")
4453  .SetInput("data", data)
4454  .CreateSymbol(symbol_name);
4455 }
4456 
4476 inline Symbol tanh(const std::string& symbol_name,
4477  Symbol data) {
4478  return Operator("tanh")
4479  .SetInput("data", data)
4480  .CreateSymbol(symbol_name);
4481 }
4482 
4500 inline Symbol arcsinh(const std::string& symbol_name,
4501  Symbol data) {
4502  return Operator("arcsinh")
4503  .SetInput("data", data)
4504  .CreateSymbol(symbol_name);
4505 }
4506 
4520 inline Symbol arccosh(const std::string& symbol_name,
4521  Symbol data) {
4522  return Operator("arccosh")
4523  .SetInput("data", data)
4524  .CreateSymbol(symbol_name);
4525 }
4526 
4544 inline Symbol arctanh(const std::string& symbol_name,
4545  Symbol data) {
4546  return Operator("arctanh")
4547  .SetInput("data", data)
4548  .CreateSymbol(symbol_name);
4549 }
4550 
/*!
 * \brief Choices for the "pool_type" parameter of Pooling.
 *
 * The numeric value of each enumerator is used by Pooling as an index into
 * its table of parameter strings.
 */
enum class PoolingPoolType {
  kAvg = 0,  ///< index of "avg"
  kLp = 1,   ///< index of "lp"
  kMax = 2,  ///< index of "max"
  kSum = 3   ///< index of "sum"
};
4559 
4563  kFull = 0,
4564  kSame = 1,
4565  kValid = 2
4566 };
4567 
4636 inline Symbol Pooling(const std::string& symbol_name,
4637  Symbol data,
4638  Shape kernel = Shape(),
4640  bool global_pool = false,
4641  bool cudnn_off = false,
4643  Shape stride = Shape(),
4644  Shape pad = Shape(),
4645  dmlc::optional<int> p_value = dmlc::optional<int>(),
4646  dmlc::optional<bool> count_include_pad = dmlc::optional<bool>()) {
4647  static const char *PoolingPoolTypeValues[] = {
4648  "avg",
4649  "lp",
4650  "max",
4651  "sum"
4652  };
4653  static const char *PoolingPoolingConventionValues[] = {
4654  "full",
4655  "same",
4656  "valid"
4657  };
4658  return Operator("Pooling")
4659  .SetParam("kernel", kernel)
4660  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
4661  .SetParam("global_pool", global_pool)
4662  .SetParam("cudnn_off", cudnn_off)
4663  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
4664  .SetParam("stride", stride)
4665  .SetParam("pad", pad)
4666  .SetParam("p_value", p_value)
4667  .SetParam("count_include_pad", count_include_pad)
4668  .SetInput("data", data)
4669  .CreateSymbol(symbol_name);
4670 }
4671 
4704 inline Symbol softmax(const std::string& symbol_name,
4705  Symbol data,
4706  int axis = -1,
4707  dmlc::optional<double> temperature = dmlc::optional<double>()) {
4708  return Operator("softmax")
4709  .SetParam("axis", axis)
4710  .SetParam("temperature", temperature)
4711  .SetInput("data", data)
4712  .CreateSymbol(symbol_name);
4713 }
4714 
4748 inline Symbol softmin(const std::string& symbol_name,
4749  Symbol data,
4750  int axis = -1,
4751  dmlc::optional<double> temperature = dmlc::optional<double>()) {
4752  return Operator("softmin")
4753  .SetParam("axis", axis)
4754  .SetParam("temperature", temperature)
4755  .SetInput("data", data)
4756  .CreateSymbol(symbol_name);
4757 }
4758 
4782 inline Symbol log_softmax(const std::string& symbol_name,
4783  Symbol data,
4784  int axis = -1,
4785  dmlc::optional<double> temperature = dmlc::optional<double>()) {
4786  return Operator("log_softmax")
4787  .SetParam("axis", axis)
4788  .SetParam("temperature", temperature)
4789  .SetInput("data", data)
4790  .CreateSymbol(symbol_name);
4791 }
4792 
4796  kNone = 0,
4797  kFastest = 1,
4798  kLimited_workspace = 2,
4799  kOff = 3
4800 };
4801 
4805  kNone = 0,
4806  kNCDHW = 1,
4807  kNCHW = 2,
4808  kNCW = 3,
4809  kNDHWC = 4,
4810  kNHWC = 5
4811 };
4812 
4842 inline Symbol Deconvolution(const std::string& symbol_name,
4843  Symbol data,
4844  Symbol weight,
4845  Symbol bias,
4846  Shape kernel,
4847  uint32_t num_filter,
4848  Shape stride = Shape(),
4849  Shape dilate = Shape(),
4850  Shape pad = Shape(),
4851  Shape adj = Shape(),
4852  Shape target_shape = Shape(),
4853  uint32_t num_group = 1,
4854  uint64_t workspace = 512,
4855  bool no_bias = true,
4857  bool cudnn_off = false,
4859  static const char *DeconvolutionCudnnTuneValues[] = {
4860  "None",
4861  "fastest",
4862  "limited_workspace",
4863  "off"
4864  };
4865  static const char *DeconvolutionLayoutValues[] = {
4866  "None",
4867  "NCDHW",
4868  "NCHW",
4869  "NCW",
4870  "NDHWC",
4871  "NHWC"
4872  };
4873  return Operator("Deconvolution")
4874  .SetParam("kernel", kernel)
4875  .SetParam("num_filter", num_filter)
4876  .SetParam("stride", stride)
4877  .SetParam("dilate", dilate)
4878  .SetParam("pad", pad)
4879  .SetParam("adj", adj)
4880  .SetParam("target_shape", target_shape)
4881  .SetParam("num_group", num_group)
4882  .SetParam("workspace", workspace)
4883  .SetParam("no_bias", no_bias)
4884  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
4885  .SetParam("cudnn_off", cudnn_off)
4886  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
4887  .SetInput("data", data)
4888  .SetInput("weight", weight)
4889  .SetInput("bias", bias)
4890  .CreateSymbol(symbol_name);
4891 }
4892 
/*!
 * \brief Choices for the "act_type" parameter of Activation.
 *
 * The numeric value of each enumerator is used by Activation as an index into
 * its table of parameter strings.
 */
enum class ActivationActType {
  kRelu = 0,      ///< index of "relu"
  kSigmoid = 1,   ///< index of "sigmoid"
  kSoftrelu = 2,  ///< index of "softrelu"
  kSoftsign = 3,  ///< index of "softsign"
  kTanh = 4       ///< index of "tanh"
};
4902 
4922 inline Symbol Activation(const std::string& symbol_name,
4923  Symbol data,
4924  ActivationActType act_type) {
4925  static const char *ActivationActTypeValues[] = {
4926  "relu",
4927  "sigmoid",
4928  "softrelu",
4929  "softsign",
4930  "tanh"
4931  };
4932  return Operator("Activation")
4933  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
4934  .SetInput("data", data)
4935  .CreateSymbol(symbol_name);
4936 }
4937 
5006 inline Symbol BatchNorm(const std::string& symbol_name,
5007  Symbol data,
5008  Symbol gamma,
5009  Symbol beta,
5010  Symbol moving_mean,
5011  Symbol moving_var,
5012  double eps = 0.001,
5013  mx_float momentum = 0.9,
5014  bool fix_gamma = true,
5015  bool use_global_stats = false,
5016  bool output_mean_var = false,
5017  int axis = 1,
5018  bool cudnn_off = false) {
5019  return Operator("BatchNorm")
5020  .SetParam("eps", eps)
5021  .SetParam("momentum", momentum)
5022  .SetParam("fix_gamma", fix_gamma)
5023  .SetParam("use_global_stats", use_global_stats)
5024  .SetParam("output_mean_var", output_mean_var)
5025  .SetParam("axis", axis)
5026  .SetParam("cudnn_off", cudnn_off)
5027  .SetInput("data", data)
5028  .SetInput("gamma", gamma)
5029  .SetInput("beta", beta)
5030  .SetInput("moving_mean", moving_mean)
5031  .SetInput("moving_var", moving_var)
5032  .CreateSymbol(symbol_name);
5033 }
5034 
/*!
 * \brief Choices for the "blank_label" parameter of CTCLoss.
 *
 * The numeric value of each enumerator is used by CTCLoss as an index into
 * its table of parameter strings.
 */
enum class CTCLossBlankLabel {
  kFirst = 0,  ///< index of "first"
  kLast = 1    ///< index of "last"
};
5045 
5112 inline Symbol CTCLoss(const std::string& symbol_name,
5113  Symbol data,
5114  Symbol label,
5115  Symbol data_lengths,
5116  Symbol label_lengths,
5117  bool use_data_lengths = false,
5118  bool use_label_lengths = false,
5120  static const char *CTCLossBlankLabelValues[] = {
5121  "first",
5122  "last"
5123  };
5124  return Operator("CTCLoss")
5125  .SetParam("use_data_lengths", use_data_lengths)
5126  .SetParam("use_label_lengths", use_label_lengths)
5127  .SetParam("blank_label", CTCLossBlankLabelValues[int(blank_label)])
5128  .SetInput("data", data)
5129  .SetInput("label", label)
5130  .SetInput("data_lengths", data_lengths)
5131  .SetInput("label_lengths", label_lengths)
5132  .CreateSymbol(symbol_name);
5133 }
5134 
5138  kNone = 0,
5139  kFastest = 1,
5140  kLimited_workspace = 2,
5141  kOff = 3
5142 };
5143 
5147 enum class ConvolutionLayout {
5148  kNone = 0,
5149  kNCDHW = 1,
5150  kNCHW = 2,
5151  kNCW = 3,
5152  kNDHWC = 4,
5153  kNHWC = 5
5154 };
5155 
5253 inline Symbol Convolution(const std::string& symbol_name,
5254  Symbol data,
5255  Symbol weight,
5256  Symbol bias,
5257  Shape kernel,
5258  uint32_t num_filter,
5259  Shape stride = Shape(),
5260  Shape dilate = Shape(),
5261  Shape pad = Shape(),
5262  uint32_t num_group = 1,
5263  uint64_t workspace = 1024,
5264  bool no_bias = false,
5266  bool cudnn_off = false,
5268  static const char *ConvolutionCudnnTuneValues[] = {
5269  "None",
5270  "fastest",
5271  "limited_workspace",
5272  "off"
5273  };
5274  static const char *ConvolutionLayoutValues[] = {
5275  "None",
5276  "NCDHW",
5277  "NCHW",
5278  "NCW",
5279  "NDHWC",
5280  "NHWC"
5281  };
5282  return Operator("Convolution")
5283  .SetParam("kernel", kernel)
5284  .SetParam("num_filter", num_filter)
5285  .SetParam("stride", stride)
5286  .SetParam("dilate", dilate)
5287  .SetParam("pad", pad)
5288  .SetParam("num_group", num_group)
5289  .SetParam("workspace", workspace)
5290  .SetParam("no_bias", no_bias)
5291  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
5292  .SetParam("cudnn_off", cudnn_off)
5293  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
5294  .SetInput("data", data)
5295  .SetInput("weight", weight)
5296  .SetInput("bias", bias)
5297  .CreateSymbol(symbol_name);
5298 }
5299 
5303  kBilinear = 0,
5304  kNearest = 1
5305 };
5306 
5311  kConcat = 0,
5312  kSum = 1
5313 };
5314 
5330 inline Symbol UpSampling(const std::string& symbol_name,
5331  const std::vector<Symbol>& data,
5332  int scale,
5333  UpSamplingSampleType sample_type,
5334  int num_args,
5335  int num_filter = 0,
5337  uint64_t workspace = 512) {
5338  static const char *UpSamplingSampleTypeValues[] = {
5339  "bilinear",
5340  "nearest"
5341  };
5342  static const char *UpSamplingMultiInputModeValues[] = {
5343  "concat",
5344  "sum"
5345  };
5346  return Operator("UpSampling")
5347  .SetParam("scale", scale)
5348  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
5349  .SetParam("num_args", num_args)
5350  .SetParam("num_filter", num_filter)
5351  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
5352  .SetParam("workspace", workspace)
5353 (data)
5354  .CreateSymbol(symbol_name);
5355 }
5356 
5403 inline Symbol Concat(const std::string& symbol_name,
5404  const std::vector<Symbol>& data,
5405  int num_args,
5406  int dim = 1) {
5407  return Operator("Concat")
5408  .SetParam("num_args", num_args)
5409  .SetParam("dim", dim)
5410 (data)
5411  .CreateSymbol(symbol_name);
5412 }
5413 
5452 inline Symbol LayerNorm(const std::string& symbol_name,
5453  Symbol data,
5454  Symbol gamma,
5455  Symbol beta,
5456  int axis = -1,
5457  mx_float eps = 1e-05,
5458  bool output_mean_var = false) {
5459  return Operator("LayerNorm")
5460  .SetParam("axis", axis)
5461  .SetParam("eps", eps)
5462  .SetParam("output_mean_var", output_mean_var)
5463  .SetInput("data", data)
5464  .SetInput("gamma", gamma)
5465  .SetInput("beta", beta)
5466  .CreateSymbol(symbol_name);
5467 }
5468 
5496 inline Symbol LRN(const std::string& symbol_name,
5497  Symbol data,
5498  uint32_t nsize,
5499  mx_float alpha = 0.0001,
5500  mx_float beta = 0.75,
5501  mx_float knorm = 2) {
5502  return Operator("LRN")
5503  .SetParam("nsize", nsize)
5504  .SetParam("alpha", alpha)
5505  .SetParam("beta", beta)
5506  .SetParam("knorm", knorm)
5507  .SetInput("data", data)
5508  .CreateSymbol(symbol_name);
5509 }
5510 
/*!
 * \brief Choices for the "mode" parameter of Dropout.
 *
 * The numeric value of each enumerator is used by Dropout as an index into
 * its table of parameter strings.
 */
enum class DropoutMode {
  kAlways = 0,   ///< index of "always"
  kTraining = 1  ///< index of "training"
};
5517 
5558 inline Symbol Dropout(const std::string& symbol_name,
5559  Symbol data,
5560  mx_float p = 0.5,
5562  Shape axes = Shape()) {
5563  static const char *DropoutModeValues[] = {
5564  "always",
5565  "training"
5566  };
5567  return Operator("Dropout")
5568  .SetParam("p", p)
5569  .SetParam("mode", DropoutModeValues[int(mode)])
5570  .SetParam("axes", axes)
5571  .SetInput("data", data)
5572  .CreateSymbol(symbol_name);
5573 }
5574 
5579  kChannel = 0,
5580  kInstance = 1
5581 };
5582 
5616 inline Symbol SoftmaxActivation(const std::string& symbol_name,
5617  Symbol data,
5619  static const char *SoftmaxActivationModeValues[] = {
5620  "channel",
5621  "instance"
5622  };
5623  return Operator("SoftmaxActivation")
5624  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
5625  .SetInput("data", data)
5626  .CreateSymbol(symbol_name);
5627 }
5628 
5672 inline Symbol FullyConnected(const std::string& symbol_name,
5673  Symbol data,
5674  Symbol weight,
5675  Symbol bias,
5676  int num_hidden,
5677  bool no_bias = false,
5678  bool flatten = true) {
5679  return Operator("FullyConnected")
5680  .SetParam("num_hidden", num_hidden)
5681  .SetParam("no_bias", no_bias)
5682  .SetParam("flatten", flatten)
5683  .SetInput("data", data)
5684  .SetInput("weight", weight)
5685  .SetInput("bias", bias)
5686  .CreateSymbol(symbol_name);
5687 }
5688 
/*!
 * \brief Choices for the "mode" parameter of Pad.
 *
 * The numeric value of each enumerator is used by Pad as an index into its
 * table of parameter strings.
 */
enum class PadMode {
  kConstant = 0,  ///< index of "constant"
  kEdge = 1,      ///< index of "edge"
  kReflect = 2    ///< index of "reflect"
};
5697 
5794 inline Symbol Pad(const std::string& symbol_name,
5795  Symbol data,
5796  PadMode mode,
5797  Shape pad_width,
5798  double constant_value = 0) {
5799  static const char *PadModeValues[] = {
5800  "constant",
5801  "edge",
5802  "reflect"
5803  };
5804  return Operator("Pad")
5805  .SetParam("mode", PadModeValues[int(mode)])
5806  .SetParam("pad_width", pad_width)
5807  .SetParam("constant_value", constant_value)
5808  .SetInput("data", data)
5809  .CreateSymbol(symbol_name);
5810 }
5811 
/*!
 * \brief Choices for the "act_type" parameter of LeakyReLU.
 *
 * The numeric value of each enumerator is used by LeakyReLU as an index into
 * its table of parameter strings.
 */
enum class LeakyReLUActType {
  kElu = 0,    ///< index of "elu"
  kLeaky = 1,  ///< index of "leaky"
  kPrelu = 2,  ///< index of "prelu"
  kRrelu = 3,  ///< index of "rrelu"
  kSelu = 4    ///< index of "selu"
};
5821 
5852 inline Symbol LeakyReLU(const std::string& symbol_name,
5853  Symbol data,
5854  Symbol gamma,
5856  mx_float slope = 0.25,
5857  mx_float lower_bound = 0.125,
5858  mx_float upper_bound = 0.334) {
5859  static const char *LeakyReLUActTypeValues[] = {
5860  "elu",
5861  "leaky",
5862  "prelu",
5863  "rrelu",
5864  "selu"
5865  };
5866  return Operator("LeakyReLU")
5867  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
5868  .SetParam("slope", slope)
5869  .SetParam("lower_bound", lower_bound)
5870  .SetParam("upper_bound", upper_bound)
5871  .SetInput("data", data)
5872  .SetInput("gamma", gamma)
5873  .CreateSymbol(symbol_name);
5874 }
5875 
5904 inline Symbol SwapAxis(const std::string& symbol_name,
5905  Symbol data,
5906  uint32_t dim1 = 0,
5907  uint32_t dim2 = 0) {
5908  return Operator("SwapAxis")
5909  .SetParam("dim1", dim1)
5910  .SetParam("dim2", dim2)
5911  .SetInput("data", data)
5912  .CreateSymbol(symbol_name);
5913 }
5914 
5975 inline Symbol BatchNorm_v1(const std::string& symbol_name,
5976  Symbol data,
5977  Symbol gamma,
5978  Symbol beta,
5979  mx_float eps = 0.001,
5980  mx_float momentum = 0.9,
5981  bool fix_gamma = true,
5982  bool use_global_stats = false,
5983  bool output_mean_var = false) {
5984  return Operator("BatchNorm_v1")
5985  .SetParam("eps", eps)
5986  .SetParam("momentum", momentum)
5987  .SetParam("fix_gamma", fix_gamma)
5988  .SetParam("use_global_stats", use_global_stats)
5989  .SetParam("output_mean_var", output_mean_var)
5990  .SetInput("data", data)
5991  .SetInput("gamma", gamma)
5992  .SetInput("beta", beta)
5993  .CreateSymbol(symbol_name);
5994 }
5995 
6033 inline Symbol softmax_cross_entropy(const std::string& symbol_name,
6034  Symbol data,
6035  Symbol label) {
6036  return Operator("softmax_cross_entropy")
6037  .SetInput("data", data)
6038  .SetInput("label", label)
6039  .CreateSymbol(symbol_name);
6040 }
6041 
6071 inline Symbol LinearRegressionOutput(const std::string& symbol_name,
6072  Symbol data,
6073  Symbol label,
6074  mx_float grad_scale = 1) {
6075  return Operator("LinearRegressionOutput")
6076  .SetParam("grad_scale", grad_scale)
6077  .SetInput("data", data)
6078  .SetInput("label", label)
6079  .CreateSymbol(symbol_name);
6080 }
6081 
6112 inline Symbol MAERegressionOutput(const std::string& symbol_name,
6113  Symbol data,
6114  Symbol label,
6115  mx_float grad_scale = 1) {
6116  return Operator("MAERegressionOutput")
6117  .SetParam("grad_scale", grad_scale)
6118  .SetInput("data", data)
6119  .SetInput("label", label)
6120  .CreateSymbol(symbol_name);
6121 }
6122 
6159 inline Symbol LogisticRegressionOutput(const std::string& symbol_name,
6160  Symbol data,
6161  Symbol label,
6162  mx_float grad_scale = 1) {
6163  return Operator("LogisticRegressionOutput")
6164  .SetParam("grad_scale", grad_scale)
6165  .SetInput("data", data)
6166  .SetInput("label", label)
6167  .CreateSymbol(symbol_name);
6168 }
6169 
6179 inline Symbol IdentityAttachKLSparseReg(const std::string& symbol_name,
6180  Symbol data,
6181  mx_float sparseness_target = 0.1,
6182  mx_float penalty = 0.001,
6183  mx_float momentum = 0.9) {
6184  return Operator("IdentityAttachKLSparseReg")
6185  .SetParam("sparseness_target", sparseness_target)
6186  .SetParam("penalty", penalty)
6187  .SetParam("momentum", momentum)
6188  .SetInput("data", data)
6189  .CreateSymbol(symbol_name);
6190 }
6191 
6220 inline Symbol signsgd_update(const std::string& symbol_name,
6221  Symbol weight,
6222  Symbol grad,
6223  mx_float lr,
6224  mx_float wd = 0,
6225  mx_float rescale_grad = 1,
6226  mx_float clip_gradient = -1) {
6227  return Operator("signsgd_update")
6228  .SetParam("lr", lr)
6229  .SetParam("wd", wd)
6230  .SetParam("rescale_grad", rescale_grad)
6231  .SetParam("clip_gradient", clip_gradient)
6232  .SetInput("weight", weight)
6233  .SetInput("grad", grad)
6234  .CreateSymbol(symbol_name);
6235 }
6236 
6271 inline Symbol signum_update(const std::string& symbol_name,
6272  Symbol weight,
6273  Symbol grad,
6274  Symbol mom,
6275  mx_float lr,
6276  mx_float momentum = 0,
6277  mx_float wd = 0,
6278  mx_float rescale_grad = 1,
6279  mx_float clip_gradient = -1,
6280  mx_float wd_lh = 0) {
6281  return Operator("signum_update")
6282  .SetParam("lr", lr)
6283  .SetParam("momentum", momentum)
6284  .SetParam("wd", wd)
6285  .SetParam("rescale_grad", rescale_grad)
6286  .SetParam("clip_gradient", clip_gradient)
6287  .SetParam("wd_lh", wd_lh)
6288  .SetInput("weight", weight)
6289  .SetInput("grad", grad)
6290  .SetInput("mom", mom)
6291  .CreateSymbol(symbol_name);
6292 }
6293 
6322 inline Symbol sgd_update(const std::string& symbol_name,
6323  Symbol weight,
6324  Symbol grad,
6325  mx_float lr,
6326  mx_float wd = 0,
6327  mx_float rescale_grad = 1,
6328  mx_float clip_gradient = -1,
6329  bool lazy_update = true) {
6330  return Operator("sgd_update")
6331  .SetParam("lr", lr)
6332  .SetParam("wd", wd)
6333  .SetParam("rescale_grad", rescale_grad)
6334  .SetParam("clip_gradient", clip_gradient)
6335  .SetParam("lazy_update", lazy_update)
6336  .SetInput("weight", weight)
6337  .SetInput("grad", grad)
6338  .CreateSymbol(symbol_name);
6339 }
6340 
6385 inline Symbol sgd_mom_update(const std::string& symbol_name,
6386  Symbol weight,
6387  Symbol grad,
6388  Symbol mom,
6389  mx_float lr,
6390  mx_float momentum = 0,
6391  mx_float wd = 0,
6392  mx_float rescale_grad = 1,
6393  mx_float clip_gradient = -1,
6394  bool lazy_update = true) {
6395  return Operator("sgd_mom_update")
6396  .SetParam("lr", lr)
6397  .SetParam("momentum", momentum)
6398  .SetParam("wd", wd)
6399  .SetParam("rescale_grad", rescale_grad)
6400  .SetParam("clip_gradient", clip_gradient)
6401  .SetParam("lazy_update", lazy_update)
6402  .SetInput("weight", weight)
6403  .SetInput("grad", grad)
6404  .SetInput("mom", mom)
6405  .CreateSymbol(symbol_name);
6406 }
6407 
6423 inline Symbol mp_sgd_update(const std::string& symbol_name,
6424  Symbol weight,
6425  Symbol grad,
6426  Symbol weight32,
6427  mx_float lr,
6428  mx_float wd = 0,
6429  mx_float rescale_grad = 1,
6430  mx_float clip_gradient = -1,
6431  bool lazy_update = true) {
6432  return Operator("mp_sgd_update")
6433  .SetParam("lr", lr)
6434  .SetParam("wd", wd)
6435  .SetParam("rescale_grad", rescale_grad)
6436  .SetParam("clip_gradient", clip_gradient)
6437  .SetParam("lazy_update", lazy_update)
6438  .SetInput("weight", weight)
6439  .SetInput("grad", grad)
6440  .SetInput("weight32", weight32)
6441  .CreateSymbol(symbol_name);
6442 }
6443 
6461 inline Symbol mp_sgd_mom_update(const std::string& symbol_name,
6462  Symbol weight,
6463  Symbol grad,
6464  Symbol mom,
6465  Symbol weight32,
6466  mx_float lr,
6467  mx_float momentum = 0,
6468  mx_float wd = 0,
6469  mx_float rescale_grad = 1,
6470  mx_float clip_gradient = -1,
6471  bool lazy_update = true) {
6472  return Operator("mp_sgd_mom_update")
6473  .SetParam("lr", lr)
6474  .SetParam("momentum", momentum)
6475  .SetParam("wd", wd)
6476  .SetParam("rescale_grad", rescale_grad)
6477  .SetParam("clip_gradient", clip_gradient)
6478  .SetParam("lazy_update", lazy_update)
6479  .SetInput("weight", weight)
6480  .SetInput("grad", grad)
6481  .SetInput("mom", mom)
6482  .SetInput("weight32", weight32)
6483  .CreateSymbol(symbol_name);
6484 }
6485 
6521 inline Symbol ftml_update(const std::string& symbol_name,
6522  Symbol weight,
6523  Symbol grad,
6524  Symbol d,
6525  Symbol v,
6526  Symbol z,
6527  mx_float lr,
6528  int t,
6529  mx_float beta1 = 0.6,
6530  mx_float beta2 = 0.999,
6531  double epsilon = 1e-08,
6532  mx_float wd = 0,
6533  mx_float rescale_grad = 1,
6534  mx_float clip_grad = -1) {
6535  return Operator("ftml_update")
6536  .SetParam("lr", lr)
6537  .SetParam("t", t)
6538  .SetParam("beta1", beta1)
6539  .SetParam("beta2", beta2)
6540  .SetParam("epsilon", epsilon)
6541  .SetParam("wd", wd)
6542  .SetParam("rescale_grad", rescale_grad)
6543  .SetParam("clip_grad", clip_grad)
6544  .SetInput("weight", weight)
6545  .SetInput("grad", grad)
6546  .SetInput("d", d)
6547  .SetInput("v", v)
6548  .SetInput("z", z)
6549  .CreateSymbol(symbol_name);
6550 }
6551 
6601 inline Symbol adam_update(const std::string& symbol_name,
6602  Symbol weight,
6603  Symbol grad,
6604  Symbol mean,
6605  Symbol var,
6606  mx_float lr,
6607  mx_float beta1 = 0.9,
6608  mx_float beta2 = 0.999,
6609  mx_float epsilon = 1e-08,
6610  mx_float wd = 0,
6611  mx_float rescale_grad = 1,
6612  mx_float clip_gradient = -1,
6613  bool lazy_update = true) {
6614  return Operator("adam_update")
6615  .SetParam("lr", lr)
6616  .SetParam("beta1", beta1)
6617  .SetParam("beta2", beta2)
6618  .SetParam("epsilon", epsilon)
6619  .SetParam("wd", wd)
6620  .SetParam("rescale_grad", rescale_grad)
6621  .SetParam("clip_gradient", clip_gradient)
6622  .SetParam("lazy_update", lazy_update)
6623  .SetInput("weight", weight)
6624  .SetInput("grad", grad)
6625  .SetInput("mean", mean)
6626  .SetInput("var", var)
6627  .CreateSymbol(symbol_name);
6628 }
6629 
6683 inline Symbol rmsprop_update(const std::string& symbol_name,
6684  Symbol weight,
6685  Symbol grad,
6686  Symbol n,
6687  mx_float lr,
6688  mx_float gamma1 = 0.95,
6689  mx_float epsilon = 1e-08,
6690  mx_float wd = 0,
6691  mx_float rescale_grad = 1,
6692  mx_float clip_gradient = -1,
6693  mx_float clip_weights = -1) {
6694  return Operator("rmsprop_update")
6695  .SetParam("lr", lr)
6696  .SetParam("gamma1", gamma1)
6697  .SetParam("epsilon", epsilon)
6698  .SetParam("wd", wd)
6699  .SetParam("rescale_grad", rescale_grad)
6700  .SetParam("clip_gradient", clip_gradient)
6701  .SetParam("clip_weights", clip_weights)
6702  .SetInput("weight", weight)
6703  .SetInput("grad", grad)
6704  .SetInput("n", n)
6705  .CreateSymbol(symbol_name);
6706 }
6707 
6753 inline Symbol rmspropalex_update(const std::string& symbol_name,
6754  Symbol weight,
6755  Symbol grad,
6756  Symbol n,
6757  Symbol g,
6758  Symbol delta,
6759  mx_float lr,
6760  mx_float gamma1 = 0.95,
6761  mx_float gamma2 = 0.9,
6762  mx_float epsilon = 1e-08,
6763  mx_float wd = 0,
6764  mx_float rescale_grad = 1,
6765  mx_float clip_gradient = -1,
6766  mx_float clip_weights = -1) {
6767  return Operator("rmspropalex_update")
6768  .SetParam("lr", lr)
6769  .SetParam("gamma1", gamma1)
6770  .SetParam("gamma2", gamma2)
6771  .SetParam("epsilon", epsilon)
6772  .SetParam("wd", wd)
6773  .SetParam("rescale_grad", rescale_grad)
6774  .SetParam("clip_gradient", clip_gradient)
6775  .SetParam("clip_weights", clip_weights)
6776  .SetInput("weight", weight)
6777  .SetInput("grad", grad)
6778  .SetInput("n", n)
6779  .SetInput("g", g)
6780  .SetInput("delta", delta)
6781  .CreateSymbol(symbol_name);
6782 }
6783 
6823 inline Symbol ftrl_update(const std::string& symbol_name,
6824  Symbol weight,
6825  Symbol grad,
6826  Symbol z,
6827  Symbol n,
6828  mx_float lr,
6829  mx_float lamda1 = 0.01,
6830  mx_float beta = 1,
6831  mx_float wd = 0,
6832  mx_float rescale_grad = 1,
6833  mx_float clip_gradient = -1) {
6834  return Operator("ftrl_update")
6835  .SetParam("lr", lr)
6836  .SetParam("lamda1", lamda1)
6837  .SetParam("beta", beta)
6838  .SetParam("wd", wd)
6839  .SetParam("rescale_grad", rescale_grad)
6840  .SetParam("clip_gradient", clip_gradient)
6841  .SetInput("weight", weight)
6842  .SetInput("grad", grad)
6843  .SetInput("z", z)
6844  .SetInput("n", n)
6845  .CreateSymbol(symbol_name);
6846 }
6847 
6919 inline Symbol SliceChannel(const std::string& symbol_name,
6920  Symbol data,
6921  int num_outputs,
6922  int axis = 1,
6923  bool squeeze_axis = false) {
6924  return Operator("SliceChannel")
6925  .SetParam("num_outputs", num_outputs)
6926  .SetParam("axis", axis)
6927  .SetParam("squeeze_axis", squeeze_axis)
6928  .SetInput("data", data)
6929  .CreateSymbol(symbol_name);
6930 }
6931 
6982 inline Symbol InstanceNorm(const std::string& symbol_name,
6983  Symbol data,
6984  Symbol gamma,
6985  Symbol beta,
6986  mx_float eps = 0.001) {
6987  return Operator("InstanceNorm")
6988  .SetParam("eps", eps)
6989  .SetInput("data", data)
6990  .SetInput("gamma", gamma)
6991  .SetInput("beta", beta)
6992  .CreateSymbol(symbol_name);
6993 }
6994 
/*!
 * \brief Choices for the "transform_type" parameter of GridGenerator.
 *
 * The numeric value of each enumerator is used by GridGenerator as an index
 * into its table of parameter strings.
 */
enum class GridGeneratorTransformType {
  kAffine = 0,  ///< index of "affine"
  kWarp = 1     ///< index of "warp"
};
7002 
7013 inline Symbol GridGenerator(const std::string& symbol_name,
7014  Symbol data,
7015  GridGeneratorTransformType transform_type,
7016  Shape target_shape = Shape(0,0)) {
7017  static const char *GridGeneratorTransformTypeValues[] = {
7018  "affine",
7019  "warp"
7020  };
7021  return Operator("GridGenerator")
7022  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
7023  .SetParam("target_shape", target_shape)
7024  .SetInput("data", data)
7025  .CreateSymbol(symbol_name);
7026 }
7027 
7031  kAvg = 0,
7032  kMax = 1,
7033  kSum = 2
7034 };
7035 
7039  kFull = 0,
7040  kValid = 1
7041 };
7042 
7094 inline Symbol Pooling_v1(const std::string& symbol_name,
7095  Symbol data,
7096  Shape kernel = Shape(),
7098  bool global_pool = false,
7100  Shape stride = Shape(),
7101  Shape pad = Shape()) {
7102  static const char *Pooling_v1PoolTypeValues[] = {
7103  "avg",
7104  "max",
7105  "sum"
7106  };
7107  static const char *Pooling_v1PoolingConventionValues[] = {
7108  "full",
7109  "valid"
7110  };
7111  return Operator("Pooling_v1")
7112  .SetParam("kernel", kernel)
7113  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
7114  .SetParam("global_pool", global_pool)
7115  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
7116  .SetParam("stride", stride)
7117  .SetParam("pad", pad)
7118  .SetInput("data", data)
7119  .CreateSymbol(symbol_name);
7120 }
7121 
/*!
 * \brief Choices for the "mode" parameter of RNN.
 *
 * The numeric value of each enumerator is used by RNN as an index into its
 * table of parameter strings.
 */
enum class RNNMode {
  kGru = 0,       ///< index of "gru"
  kLstm = 1,      ///< index of "lstm"
  kRnn_relu = 2,  ///< index of "rnn_relu"
  kRnn_tanh = 3   ///< index of "rnn_tanh"
};
7130 
7202 inline Symbol RNN(const std::string& symbol_name,
7203  Symbol data,
7204  Symbol parameters,
7205  Symbol state,
7206  Symbol state_cell,
7207  uint32_t state_size,
7208  uint32_t num_layers,
7209  RNNMode mode,
7210  bool bidirectional = false,
7211  mx_float p = 0,
7212  bool state_outputs = false,
7213  dmlc::optional<int> projection_size = dmlc::optional<int>(),
7214  dmlc::optional<double> lstm_state_clip_min = dmlc::optional<double>(),
7215  dmlc::optional<double> lstm_state_clip_max = dmlc::optional<double>(),
7216  bool lstm_state_clip_nan = false) {
7217  static const char *RNNModeValues[] = {
7218  "gru",
7219  "lstm",
7220  "rnn_relu",
7221  "rnn_tanh"
7222  };
7223  return Operator("RNN")
7224  .SetParam("state_size", state_size)
7225  .SetParam("num_layers", num_layers)
7226  .SetParam("mode", RNNModeValues[int(mode)])
7227  .SetParam("bidirectional", bidirectional)
7228  .SetParam("p", p)
7229  .SetParam("state_outputs", state_outputs)
7230  .SetParam("projection_size", projection_size)
7231  .SetParam("lstm_state_clip_min", lstm_state_clip_min)
7232  .SetParam("lstm_state_clip_max", lstm_state_clip_max)
7233  .SetParam("lstm_state_clip_nan", lstm_state_clip_nan)
7234  .SetInput("data", data)
7235  .SetInput("parameters", parameters)
7236  .SetInput("state", state)
7237  .SetInput("state_cell", state_cell)
7238  .CreateSymbol(symbol_name);
7239 }
7240 
7251  kNone = 0,
7252  kFastest = 1,
7253  kLimited_workspace = 2,
7254  kOff = 3
7255 };
7256 
7261  kNone = 0,
7262  kNCDHW = 1,
7263  kNCHW = 2,
7264  kNDHWC = 3,
7265  kNHWC = 4
7266 };
7267 
7298 inline Symbol Convolution_v1(const std::string& symbol_name,
7299  Symbol data,
7300  Symbol weight,
7301  Symbol bias,
7302  Shape kernel,
7303  uint32_t num_filter,
7304  Shape stride = Shape(),
7305  Shape dilate = Shape(),
7306  Shape pad = Shape(),
7307  uint32_t num_group = 1,
7308  uint64_t workspace = 1024,
7309  bool no_bias = false,
7311  bool cudnn_off = false,
7313  static const char *Convolution_v1CudnnTuneValues[] = {
7314  "None",
7315  "fastest",
7316  "limited_workspace",
7317  "off"
7318  };
7319  static const char *Convolution_v1LayoutValues[] = {
7320  "None",
7321  "NCDHW",
7322  "NCHW",
7323  "NDHWC",
7324  "NHWC"
7325  };
7326  return Operator("Convolution_v1")
7327  .SetParam("kernel", kernel)
7328  .SetParam("num_filter", num_filter)
7329  .SetParam("stride", stride)
7330  .SetParam("dilate", dilate)
7331  .SetParam("pad", pad)
7332  .SetParam("num_group", num_group)
7333  .SetParam("workspace", workspace)
7334  .SetParam("no_bias", no_bias)
7335  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
7336  .SetParam("cudnn_off", cudnn_off)
7337  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
7338  .SetInput("data", data)
7339  .SetInput("weight", weight)
7340  .SetInput("bias", bias)
7341  .CreateSymbol(symbol_name);
7342 }
7343 
7364 inline Symbol Crop(const std::string& symbol_name,
7365  const std::vector<Symbol>& data,
7366  int num_args,
7367  Shape offset = Shape(0,0),
7368  Shape h_w = Shape(0,0),
7369  bool center_crop = false) {
7370  return Operator("Crop")
7371  .SetParam("num_args", num_args)
7372  .SetParam("offset", offset)
7373  .SetParam("h_w", h_w)
7374  .SetParam("center_crop", center_crop)
7375 (data)
7376  .CreateSymbol(symbol_name);
7377 }
7378 
7455 inline Symbol SequenceReverse(const std::string& symbol_name,
7456  Symbol data,
7457  Symbol sequence_length,
7458  bool use_sequence_length = false,
7459  int axis = 0) {
7460  return Operator("SequenceReverse")
7461  .SetParam("use_sequence_length", use_sequence_length)
7462  .SetParam("axis", axis)
7463  .SetInput("data", data)
7464  .SetInput("sequence_length", sequence_length)
7465  .CreateSymbol(symbol_name);
7466 }
7467 
7471  kAffine = 0
7472 };
7473 
7477  kBilinear = 0
7478 };
7479 
7491 inline Symbol SpatialTransformer(const std::string& symbol_name,
7492  Symbol data,
7493  Symbol loc,
7494  SpatialTransformerTransformType transform_type,
7495  SpatialTransformerSamplerType sampler_type,
7496  Shape target_shape = Shape(0,0),
7497  dmlc::optional<bool> cudnn_off = dmlc::optional<bool>()) {
7498  static const char *SpatialTransformerTransformTypeValues[] = {
7499  "affine"
7500  };
7501  static const char *SpatialTransformerSamplerTypeValues[] = {
7502  "bilinear"
7503  };
7504  return Operator("SpatialTransformer")
7505  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
7506  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
7507  .SetParam("target_shape", target_shape)
7508  .SetParam("cudnn_off", cudnn_off)
7509  .SetInput("data", data)
7510  .SetInput("loc", loc)
7511  .CreateSymbol(symbol_name);
7512 }
7513 
7517  kBatch = 0,
7518  kNull = 1,
7519  kValid = 2
7520 };
7521 
/*!
 * \brief Builds a named symbol for the "SoftmaxOutput" operator.
 *
 * Forwards grad_scale, ignore_label, multi_output, use_ignore,
 * preserve_shape, normalization, out_grad and smooth_alpha as operator
 * parameters, binds data / label to the inputs of the same name, and creates
 * the symbol under symbol_name.
 *
 * NOTE(review): `normalization` is read in the body but not declared in the
 * visible parameter list (the embedded numbering skips 7625). Presumably an
 * enum-typed parameter was lost when the listing was produced - confirm
 * against the original header.
 */
7617 inline Symbol SoftmaxOutput(const std::string& symbol_name,
7618  Symbol data,
7619  Symbol label,
7620  mx_float grad_scale = 1,
7621  mx_float ignore_label = -1,
7622  bool multi_output = false,
7623  bool use_ignore = false,
7624  bool preserve_shape = false,
7626  bool out_grad = false,
7627  mx_float smooth_alpha = 0) {
// String spellings sent to the operator, indexed by the integer value of normalization.
7628  static const char *SoftmaxOutputNormalizationValues[] = {
7629  "batch",
7630  "null",
7631  "valid"
7632  };
7633  return Operator("SoftmaxOutput")
7634  .SetParam("grad_scale", grad_scale)
7635  .SetParam("ignore_label", ignore_label)
7636  .SetParam("multi_output", multi_output)
7637  .SetParam("use_ignore", use_ignore)
7638  .SetParam("preserve_shape", preserve_shape)
7639  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
7640  .SetParam("out_grad", out_grad)
7641  .SetParam("smooth_alpha", smooth_alpha)
7642  .SetInput("data", data)
7643  .SetInput("label", label)
7644  .CreateSymbol(symbol_name);
7645 }
7646 
7650  kBatch = 0,
7651  kNull = 1,
7652  kValid = 2
7653 };
7654 
/*!
 * \brief Builds a named symbol for the "Softmax" operator.
 *
 * Forwards grad_scale, ignore_label, multi_output, use_ignore,
 * preserve_shape, normalization, out_grad and smooth_alpha as operator
 * parameters, binds data to input "data", and creates the symbol under
 * symbol_name.
 *
 * NOTE(review): `normalization` is read in the body but not declared in the
 * visible parameter list (the embedded numbering skips 7689). Presumably an
 * enum-typed parameter was lost when the listing was produced - confirm
 * against the original header.
 */
7682 inline Symbol Softmax(const std::string& symbol_name,
7683  Symbol data,
7684  mx_float grad_scale = 1,
7685  mx_float ignore_label = -1,
7686  bool multi_output = false,
7687  bool use_ignore = false,
7688  bool preserve_shape = false,
7690  bool out_grad = false,
7691  mx_float smooth_alpha = 0) {
// String spellings sent to the operator, indexed by the integer value of normalization.
7692  static const char *SoftmaxNormalizationValues[] = {
7693  "batch",
7694  "null",
7695  "valid"
7696  };
7697  return Operator("Softmax")
7698  .SetParam("grad_scale", grad_scale)
7699  .SetParam("ignore_label", ignore_label)
7700  .SetParam("multi_output", multi_output)
7701  .SetParam("use_ignore", use_ignore)
7702  .SetParam("preserve_shape", preserve_shape)
7703  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
7704  .SetParam("out_grad", out_grad)
7705  .SetParam("smooth_alpha", smooth_alpha)
7706  .SetInput("data", data)
7707  .CreateSymbol(symbol_name);
7708 }
7709 
7791 inline Symbol BilinearSampler(const std::string& symbol_name,
7792  Symbol data,
7793  Symbol grid,
7794  dmlc::optional<bool> cudnn_off = dmlc::optional<bool>()) {
7795  return Operator("BilinearSampler")
7796  .SetParam("cudnn_off", cudnn_off)
7797  .SetInput("data", data)
7798  .SetInput("grid", grid)
7799  .CreateSymbol(symbol_name);
7800 }
7801 
7858 inline Symbol ROIPooling(const std::string& symbol_name,
7859  Symbol data,
7860  Symbol rois,
7861  Shape pooled_size,
7862  mx_float spatial_scale) {
7863  return Operator("ROIPooling")
7864  .SetParam("pooled_size", pooled_size)
7865  .SetParam("spatial_scale", spatial_scale)
7866  .SetInput("data", data)
7867  .SetInput("rois", rois)
7868  .CreateSymbol(symbol_name);
7869 }
7870 
7926 inline Symbol SequenceLast(const std::string& symbol_name,
7927  Symbol data,
7928  Symbol sequence_length,
7929  bool use_sequence_length = false,
7930  int axis = 0) {
7931  return Operator("SequenceLast")
7932  .SetParam("use_sequence_length", use_sequence_length)
7933  .SetParam("axis", axis)
7934  .SetInput("data", data)
7935  .SetInput("sequence_length", sequence_length)
7936  .CreateSymbol(symbol_name);
7937 }
7938 
7942  kChannel = 0,
7943  kInstance = 1,
7944  kSpatial = 2
7945 };
7946 
/*!
 * \brief Builds a named symbol for the "L2Normalization" operator.
 *
 * Forwards eps and mode as operator parameters, binds data to input "data",
 * and creates the symbol under symbol_name.
 *
 * NOTE(review): `mode` is read in the body and the parameter list is never
 * closed with `) {`, but no declaration of `mode` is visible in this listing
 * (the embedded numbering skips 8012). Presumably an enum-typed trailing
 * parameter was lost when the listing was produced - confirm against the
 * original header.
 */
8009 inline Symbol L2Normalization(const std::string& symbol_name,
8010  Symbol data,
8011  mx_float eps = 1e-10,
// String spellings sent to the operator, indexed by the integer value of mode.
8013  static const char *L2NormalizationModeValues[] = {
8014  "channel",
8015  "instance",
8016  "spatial"
8017  };
8018  return Operator("L2Normalization")
8019  .SetParam("eps", eps)
8020  .SetParam("mode", L2NormalizationModeValues[int(mode)])
8021  .SetInput("data", data)
8022  .CreateSymbol(symbol_name);
8023 }
8024 
8030  kBatch = 0,
8031  kNull = 1,
8032  kValid = 2
8033 };
8034 
/*!
 * \brief Builds a named symbol for the "MakeLoss" operator.
 *
 * Forwards grad_scale, valid_thresh and normalization as operator parameters,
 * binds data to input "data", and creates the symbol under symbol_name.
 *
 * NOTE(review): `normalization` is read in the body and the parameter list is
 * never closed with `) {`, but no declaration of it is visible in this
 * listing (the embedded numbering skips 8073). Presumably an enum-typed
 * trailing parameter was lost when the listing was produced - confirm against
 * the original header.
 */
8069 inline Symbol MakeLoss(const std::string& symbol_name,
8070  Symbol data,
8071  mx_float grad_scale = 1,
8072  mx_float valid_thresh = 0,
// String spellings sent to the operator, indexed by the integer value of normalization.
8074  static const char *MakeLossNormalizationValues[] = {
8075  "batch",
8076  "null",
8077  "valid"
8078  };
8079  return Operator("MakeLoss")
8080  .SetParam("grad_scale", grad_scale)
8081  .SetParam("valid_thresh", valid_thresh)
8082  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
8083  .SetInput("data", data)
8084  .CreateSymbol(symbol_name);
8085 }
8086 
8102 inline Symbol SVMOutput(const std::string& symbol_name,
8103  Symbol data,
8104  Symbol label,
8105  mx_float margin = 1,
8106  mx_float regularization_coefficient = 1,
8107  bool use_linear = false) {
8108  return Operator("SVMOutput")
8109  .SetParam("margin", margin)
8110  .SetParam("regularization_coefficient", regularization_coefficient)
8111  .SetParam("use_linear", use_linear)
8112  .SetInput("data", data)
8113  .SetInput("label", label)
8114  .CreateSymbol(symbol_name);
8115 }
8116 
8166 inline Symbol Correlation(const std::string& symbol_name,
8167  Symbol data1,
8168  Symbol data2,
8169  uint32_t kernel_size = 1,
8170  uint32_t max_displacement = 1,
8171  uint32_t stride1 = 1,
8172  uint32_t stride2 = 1,
8173  uint32_t pad_size = 0,
8174  bool is_multiply = true) {
8175  return Operator("Correlation")
8176  .SetParam("kernel_size", kernel_size)
8177  .SetParam("max_displacement", max_displacement)
8178  .SetParam("stride1", stride1)
8179  .SetParam("stride2", stride2)
8180  .SetParam("pad_size", pad_size)
8181  .SetParam("is_multiply", is_multiply)
8182  .SetInput("data1", data1)
8183  .SetInput("data2", data2)
8184  .CreateSymbol(symbol_name);
8185 }
8186 
8265 inline Symbol SequenceMask(const std::string& symbol_name,
8266  Symbol data,
8267  Symbol sequence_length,
8268  bool use_sequence_length = false,
8269  mx_float value = 0,
8270  int axis = 0) {
8271  return Operator("SequenceMask")
8272  .SetParam("use_sequence_length", use_sequence_length)
8273  .SetParam("value", value)
8274  .SetParam("axis", axis)
8275  .SetInput("data", data)
8276  .SetInput("sequence_length", sequence_length)
8277  .CreateSymbol(symbol_name);
8278 }
8279 
8288 inline Symbol choose_element_0index(const std::string& symbol_name,
8289  Symbol lhs,
8290  Symbol rhs) {
8291  return Operator("choose_element_0index")
8292  .SetInput("lhs", lhs)
8293  .SetInput("rhs", rhs)
8294  .CreateSymbol(symbol_name);
8295 }
8296 
8306 inline Symbol fill_element_0index(const std::string& symbol_name,
8307  Symbol lhs,
8308  Symbol mhs,
8309  Symbol rhs) {
8310  return Operator("fill_element_0index")
8311  .SetInput("lhs", lhs)
8312  .SetInput("mhs", mhs)
8313  .SetInput("rhs", rhs)
8314  .CreateSymbol(symbol_name);
8315 }
8316 
8356 inline Symbol khatri_rao(const std::vector<Symbol>& args) {
8357  return Operator("khatri_rao")
8358 (args)
8359  .CreateSymbol();
8360 }
8361 
8376 inline Symbol Custom(const std::vector<Symbol>& data,
8377  const std::string& op_type) {
8378  return Operator("Custom")
8379 (data)
8380  .CreateSymbol();
8381 }
8382 
8405  Symbol rhs) {
8406  return Operator("broadcast_power")
8407  .SetInput("lhs", lhs)
8408  .SetInput("rhs", rhs)
8409  .CreateSymbol();
8410 }
8411 
8436  Symbol rhs) {
8437  return Operator("broadcast_maximum")
8438  .SetInput("lhs", lhs)
8439  .SetInput("rhs", rhs)
8440  .CreateSymbol();
8441 }
8442 
8467  Symbol rhs) {
8468  return Operator("broadcast_minimum")
8469  .SetInput("lhs", lhs)
8470  .SetInput("rhs", rhs)
8471  .CreateSymbol();
8472 }
8473 
8504  Symbol rhs) {
8505  return Operator("broadcast_hypot")
8506  .SetInput("lhs", lhs)
8507  .SetInput("rhs", rhs)
8508  .CreateSymbol();
8509 }
8510 
8584 inline Symbol Reshape(Symbol data,
8585  Shape shape = Shape(),
8586  bool reverse = false,
8587  Shape target_shape = Shape(),
8588  bool keep_highest = false) {
8589  return Operator("Reshape")
8590  .SetParam("shape", shape)
8591  .SetParam("reverse", reverse)
8592  .SetParam("target_shape", target_shape)
8593  .SetParam("keep_highest", keep_highest)
8594  .SetInput("data", data)
8595  .CreateSymbol();
8596 }
8597 
8630 inline Symbol Flatten(Symbol data) {
8631  return Operator("Flatten")
8632  .SetInput("data", data)
8633  .CreateSymbol();
8634 }
8635 
8672  Shape axes = Shape()) {
8673  return Operator("transpose")
8674  .SetParam("axes", axes)
8675  .SetInput("data", data)
8676  .CreateSymbol();
8677 }
8678 
8694  int axis) {
8695  return Operator("expand_dims")
8696  .SetParam("axis", axis)
8697  .SetInput("data", data)
8698  .CreateSymbol();
8699 }
8700 
8755 inline Symbol slice(Symbol data,
8756  Shape begin,
8757  Shape end,
8758  Shape step = Shape()) {
8759  return Operator("slice")
8760  .SetParam("begin", begin)
8761  .SetParam("end", end)
8762  .SetParam("step", step)
8763  .SetInput("data", data)
8764  .CreateSymbol();
8765 }
8766 
8799  int axis,
8800  int begin,
8801  dmlc::optional<int> end) {
8802  return Operator("slice_axis")
8803  .SetParam("axis", axis)
8804  .SetParam("begin", begin)
8805  .SetParam("end", end)
8806  .SetInput("data", data)
8807  .CreateSymbol();
8808 }
8809 
8871  Symbol shape_like,
8872  Shape axes = Shape()) {
8873  return Operator("slice_like")
8874  .SetParam("axes", axes)
8875  .SetInput("data", data)
8876  .SetInput("shape_like", shape_like)
8877  .CreateSymbol();
8878 }
8879 
8913 inline Symbol clip(Symbol data,
8914  mx_float a_min,
8915  mx_float a_max) {
8916  return Operator("clip")
8917  .SetParam("a_min", a_min)
8918  .SetParam("a_max", a_max)
8919  .SetInput("data", data)
8920  .CreateSymbol();
8921 }
8922 
8956 inline Symbol repeat(Symbol data,
8957  int repeats,
8958  dmlc::optional<int> axis = dmlc::optional<int>()) {
8959  return Operator("repeat")
8960  .SetParam("repeats", repeats)
8961  .SetParam("axis", axis)
8962  .SetInput("data", data)
8963  .CreateSymbol();
8964 }
8965 
9010 inline Symbol tile(Symbol data,
9011  Shape reps) {
9012  return Operator("tile")
9013  .SetParam("reps", reps)
9014  .SetInput("data", data)
9015  .CreateSymbol();
9016 }
9017 
9040 inline Symbol reverse(Symbol data,
9041  Shape axis) {
9042  return Operator("reverse")
9043  .SetParam("axis", axis)
9044  .SetInput("data", data)
9045  .CreateSymbol();
9046 }
9047 
9070 inline Symbol stack(const std::vector<Symbol>& data,
9071  int num_args,
9072  int axis = 0) {
9073  return Operator("stack")
9074  .SetParam("num_args", num_args)
9075  .SetParam("axis", axis)
9076 (data)
9077  .CreateSymbol();
9078 }
9079 
9101 inline Symbol squeeze(const std::vector<Symbol>& data,
9102  dmlc::optional<Shape> axis = dmlc::optional<Shape>()) {
9103  return Operator("squeeze")
9104  .SetParam("axis", axis)
9105 (data)
9106  .CreateSymbol();
9107 }
9108 
9150  int block_size) {
9151  return Operator("depth_to_space")
9152  .SetParam("block_size", block_size)
9153  .SetInput("data", data)
9154  .CreateSymbol();
9155 }
9156 
9200  int block_size) {
9201  return Operator("space_to_depth")
9202  .SetParam("block_size", block_size)
9203  .SetInput("data", data)
9204  .CreateSymbol();
9205 }
9206 
9229 inline Symbol zeros_like(Symbol data) {
9230  return Operator("zeros_like")
9231  .SetInput("data", data)
9232  .CreateSymbol();
9233 }
9234 
9251 inline Symbol ones_like(Symbol data) {
9252  return Operator("ones_like")
9253  .SetInput("data", data)
9254  .CreateSymbol();
9255 }
9256 
9278 inline Symbol add_n(const std::vector<Symbol>& args) {
9279  return Operator("add_n")
9280 (args)
9281  .CreateSymbol();
9282 }
9283 
9314 inline Symbol argmax(Symbol data,
9315  dmlc::optional<int> axis = dmlc::optional<int>(),
9316  bool keepdims = false) {
9317  return Operator("argmax")
9318  .SetParam("axis", axis)
9319  .SetParam("keepdims", keepdims)
9320  .SetInput("data", data)
9321  .CreateSymbol();
9322 }
9323 
9354 inline Symbol argmin(Symbol data,
9355  dmlc::optional<int> axis = dmlc::optional<int>(),
9356  bool keepdims = false) {
9357  return Operator("argmin")
9358  .SetParam("axis", axis)
9359  .SetParam("keepdims", keepdims)
9360  .SetInput("data", data)
9361  .CreateSymbol();
9362 }
9363 
9386  return Operator("argmax_channel")
9387  .SetInput("data", data)
9388  .CreateSymbol();
9389 }
9390 
9446 inline Symbol pick(Symbol data,
9447  Symbol index,
9448  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9449  bool keepdims = false,
9450  PickMode mode = PickMode::kClip) {
9451  static const char *PickModeValues[] = {
9452  "clip",
9453  "wrap"
9454  };
9455  return Operator("pick")
9456  .SetParam("axis", axis)
9457  .SetParam("keepdims", keepdims)
9458  .SetParam("mode", PickModeValues[int(mode)])
9459  .SetInput("data", data)
9460  .SetInput("index", index)
9461  .CreateSymbol();
9462 }
9463 
9521 inline Symbol dot(Symbol lhs,
9522  Symbol rhs,
9523  bool transpose_a = false,
9524  bool transpose_b = false,
9525  DotForwardStype forward_stype = DotForwardStype::kNone) {
9526  static const char *DotForwardStypeValues[] = {
9527  "None",
9528  "csr",
9529  "default",
9530  "row_sparse"
9531  };
9532  return Operator("dot")
9533  .SetParam("transpose_a", transpose_a)
9534  .SetParam("transpose_b", transpose_b)
9535  .SetParam("forward_stype", DotForwardStypeValues[int(forward_stype)])
9536  .SetInput("lhs", lhs)
9537  .SetInput("rhs", rhs)
9538  .CreateSymbol();
9539 }
9540 
9566  Symbol rhs,
9567  bool transpose_a = false,
9568  bool transpose_b = false,
9570  static const char *Batch_dotForwardStypeValues[] = {
9571  "None",
9572  "csr",
9573  "default",
9574  "row_sparse"
9575  };
9576  return Operator("batch_dot")
9577  .SetParam("transpose_a", transpose_a)
9578  .SetParam("transpose_b", transpose_b)
9579  .SetParam("forward_stype", Batch_dotForwardStypeValues[int(forward_stype)])
9580  .SetInput("lhs", lhs)
9581  .SetInput("rhs", rhs)
9582  .CreateSymbol();
9583 }
9584 
9617  Symbol rhs) {
9618  return Operator("broadcast_add")
9619  .SetInput("lhs", lhs)
9620  .SetInput("rhs", rhs)
9621  .CreateSymbol();
9622 }
9623 
9656  Symbol rhs) {
9657  return Operator("broadcast_sub")
9658  .SetInput("lhs", lhs)
9659  .SetInput("rhs", rhs)
9660  .CreateSymbol();
9661 }
9662 
9689  Symbol rhs) {
9690  return Operator("broadcast_mul")
9691  .SetInput("lhs", lhs)
9692  .SetInput("rhs", rhs)
9693  .CreateSymbol();
9694 }
9695 
9722  Symbol rhs) {
9723  return Operator("broadcast_div")
9724  .SetInput("lhs", lhs)
9725  .SetInput("rhs", rhs)
9726  .CreateSymbol();
9727 }
9728 
9751  Symbol rhs) {
9752  return Operator("broadcast_mod")
9753  .SetInput("lhs", lhs)
9754  .SetInput("rhs", rhs)
9755  .CreateSymbol();
9756 }
9757 
9776 inline Symbol relu(Symbol data) {
9777  return Operator("relu")
9778  .SetInput("data", data)
9779  .CreateSymbol();
9780 }
9781 
9796 inline Symbol sigmoid(Symbol data) {
9797  return Operator("sigmoid")
9798  .SetInput("data", data)
9799  .CreateSymbol();
9800 }
9801 
9817  mx_float alpha = 0.2,
9818  mx_float beta = 0.5) {
9819  return Operator("hard_sigmoid")
9820  .SetParam("alpha", alpha)
9821  .SetParam("beta", beta)
9822  .SetInput("data", data)
9823  .CreateSymbol();
9824 }
9825 
9840 inline Symbol softsign(Symbol data) {
9841  return Operator("softsign")
9842  .SetInput("data", data)
9843  .CreateSymbol();
9844 }
9845 
9878 inline Symbol BlockGrad(Symbol data) {
9879  return Operator("BlockGrad")
9880  .SetInput("data", data)
9881  .CreateSymbol();
9882 }
9883 
9912 inline Symbol make_loss(Symbol data) {
9913  return Operator("make_loss")
9914  .SetInput("data", data)
9915  .CreateSymbol();
9916 }
9917 
9952  Symbol rhs) {
9953  return Operator("reshape_like")
9954  .SetInput("lhs", lhs)
9955  .SetInput("rhs", rhs)
9956  .CreateSymbol();
9957 }
9958 
9977  dmlc::optional<int> lhs_begin = dmlc::optional<int>(),
9978  dmlc::optional<int> lhs_end = dmlc::optional<int>(),
9979  dmlc::optional<int> rhs_begin = dmlc::optional<int>(),
9980  dmlc::optional<int> rhs_end = dmlc::optional<int>()) {
9981  return Operator("shape_array")
9982  .SetParam("lhs_begin", lhs_begin)
9983  .SetParam("lhs_end", lhs_end)
9984  .SetParam("rhs_begin", rhs_begin)
9985  .SetParam("rhs_end", rhs_end)
9986  .SetInput("data", data)
9987  .CreateSymbol();
9988 }
9989 
10003 inline Symbol size_array(Symbol data) {
10004  return Operator("size_array")
10005  .SetInput("data", data)
10006  .CreateSymbol();
10007 }
10008 
10027 inline Symbol Cast(Symbol data,
10028  CastDtype dtype) {
10029  static const char *CastDtypeValues[] = {
10030  "float16",
10031  "float32",
10032  "float64",
10033  "int32",
10034  "int64",
10035  "int8",
10036  "uint8"
10037  };
10038  return Operator("Cast")
10039  .SetParam("dtype", CastDtypeValues[int(dtype)])
10040  .SetInput("data", data)
10041  .CreateSymbol();
10042 }
10043 
10057 inline Symbol negative(Symbol data) {
10058  return Operator("negative")
10059  .SetInput("data", data)
10060  .CreateSymbol();
10061 }
10062 
10078 inline Symbol reciprocal(Symbol data) {
10079  return Operator("reciprocal")
10080  .SetInput("data", data)
10081  .CreateSymbol();
10082 }
10083 
10103 inline Symbol abs(Symbol data) {
10104  return Operator("abs")
10105  .SetInput("data", data)
10106  .CreateSymbol();
10107 }
10108 
10128 inline Symbol sign(Symbol data) {
10129  return Operator("sign")
10130  .SetInput("data", data)
10131  .CreateSymbol();
10132 }
10133 
10153 inline Symbol round(Symbol data) {
10154  return Operator("round")
10155  .SetInput("data", data)
10156  .CreateSymbol();
10157 }
10158 
10182 inline Symbol rint(Symbol data) {
10183  return Operator("rint")
10184  .SetInput("data", data)
10185  .CreateSymbol();
10186 }
10187 
10209 inline Symbol ceil(Symbol data) {
10210  return Operator("ceil")
10211  .SetInput("data", data)
10212  .CreateSymbol();
10213 }
10214 
10236 inline Symbol floor(Symbol data) {
10237  return Operator("floor")
10238  .SetInput("data", data)
10239  .CreateSymbol();
10240 }
10241 
10264 inline Symbol trunc(Symbol data) {
10265  return Operator("trunc")
10266  .SetInput("data", data)
10267  .CreateSymbol();
10268 }
10269 
10290 inline Symbol fix(Symbol data) {
10291  return Operator("fix")
10292  .SetInput("data", data)
10293  .CreateSymbol();
10294 }
10295 
10318 inline Symbol square(Symbol data) {
10319  return Operator("square")
10320  .SetInput("data", data)
10321  .CreateSymbol();
10322 }
10323 
10346 inline Symbol sqrt(Symbol data) {
10347  return Operator("sqrt")
10348  .SetInput("data", data)
10349  .CreateSymbol();
10350 }
10351 
10370 inline Symbol rsqrt(Symbol data) {
10371  return Operator("rsqrt")
10372  .SetInput("data", data)
10373  .CreateSymbol();
10374 }
10375 
10398 inline Symbol cbrt(Symbol data) {
10399  return Operator("cbrt")
10400  .SetInput("data", data)
10401  .CreateSymbol();
10402 }
10403 
10417 inline Symbol erf(Symbol data) {
10418  return Operator("erf")
10419  .SetInput("data", data)
10420  .CreateSymbol();
10421 }
10422 
10439 inline Symbol rcbrt(Symbol data) {
10440  return Operator("rcbrt")
10441  .SetInput("data", data)
10442  .CreateSymbol();
10443 }
10444 
10463 inline Symbol exp(Symbol data) {
10464  return Operator("exp")
10465  .SetInput("data", data)
10466  .CreateSymbol();
10467 }
10468 
10482 inline Symbol log(Symbol data) {
10483  return Operator("log")
10484  .SetInput("data", data)
10485  .CreateSymbol();
10486 }
10487 
10501 inline Symbol log10(Symbol data) {
10502  return Operator("log10")
10503  .SetInput("data", data)
10504  .CreateSymbol();
10505 }
10506 
10520 inline Symbol log2(Symbol data) {
10521  return Operator("log2")
10522  .SetInput("data", data)
10523  .CreateSymbol();
10524 }
10525 
10544 inline Symbol log1p(Symbol data) {
10545  return Operator("log1p")
10546  .SetInput("data", data)
10547  .CreateSymbol();
10548 }
10549 
10567 inline Symbol expm1(Symbol data) {
10568  return Operator("expm1")
10569  .SetInput("data", data)
10570  .CreateSymbol();
10571 }
10572 
10583 inline Symbol gamma(Symbol data) {
10584  return Operator("gamma")
10585  .SetInput("data", data)
10586  .CreateSymbol();
10587 }
10588 
10599 inline Symbol gammaln(Symbol data) {
10600  return Operator("gammaln")
10601  .SetInput("data", data)
10602  .CreateSymbol();
10603 }
10604 
10615 inline Symbol logical_not(Symbol data) {
10616  return Operator("logical_not")
10617  .SetInput("data", data)
10618  .CreateSymbol();
10619 }
10620 
10678 inline Symbol sum(Symbol data,
10679  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10680  bool keepdims = false,
10681  bool exclude = false) {
10682  return Operator("sum")
10683  .SetParam("axis", axis)
10684  .SetParam("keepdims", keepdims)
10685  .SetParam("exclude", exclude)
10686  .SetInput("data", data)
10687  .CreateSymbol();
10688 }
10689 
10713 inline Symbol mean(Symbol data,
10714  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10715  bool keepdims = false,
10716  bool exclude = false) {
10717  return Operator("mean")
10718  .SetParam("axis", axis)
10719  .SetParam("keepdims", keepdims)
10720  .SetParam("exclude", exclude)
10721  .SetInput("data", data)
10722  .CreateSymbol();
10723 }
10724 
10748 inline Symbol prod(Symbol data,
10749  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10750  bool keepdims = false,
10751  bool exclude = false) {
10752  return Operator("prod")
10753  .SetParam("axis", axis)
10754  .SetParam("keepdims", keepdims)
10755  .SetParam("exclude", exclude)
10756  .SetInput("data", data)
10757  .CreateSymbol();
10758 }
10759 
10785 inline Symbol nansum(Symbol data,
10786  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10787  bool keepdims = false,
10788  bool exclude = false) {
10789  return Operator("nansum")
10790  .SetParam("axis", axis)
10791  .SetParam("keepdims", keepdims)
10792  .SetParam("exclude", exclude)
10793  .SetInput("data", data)
10794  .CreateSymbol();
10795 }
10796 
10822 inline Symbol nanprod(Symbol data,
10823  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10824  bool keepdims = false,
10825  bool exclude = false) {
10826  return Operator("nanprod")
10827  .SetParam("axis", axis)
10828  .SetParam("keepdims", keepdims)
10829  .SetParam("exclude", exclude)
10830  .SetInput("data", data)
10831  .CreateSymbol();
10832 }
10833 
10857 inline Symbol max(Symbol data,
10858  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10859  bool keepdims = false,
10860  bool exclude = false) {
10861  return Operator("max")
10862  .SetParam("axis", axis)
10863  .SetParam("keepdims", keepdims)
10864  .SetParam("exclude", exclude)
10865  .SetInput("data", data)
10866  .CreateSymbol();
10867 }
10868 
10892 inline Symbol min(Symbol data,
10893  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10894  bool keepdims = false,
10895  bool exclude = false) {
10896  return Operator("min")
10897  .SetParam("axis", axis)
10898  .SetParam("keepdims", keepdims)
10899  .SetParam("exclude", exclude)
10900  .SetInput("data", data)
10901  .CreateSymbol();
10902 }
10903 
10933  Shape axis = Shape(),
10934  Shape size = Shape()) {
10935  return Operator("broadcast_axis")
10936  .SetParam("axis", axis)
10937  .SetParam("size", size)
10938  .SetInput("data", data)
10939  .CreateSymbol();
10940 }
10941 
10970  Shape shape = Shape()) {
10971  return Operator("broadcast_to")
10972  .SetParam("shape", shape)
10973  .SetInput("data", data)
10974  .CreateSymbol();
10975 }
10976 
11005  Symbol rhs,
11006  dmlc::optional<Shape> lhs_axes = dmlc::optional<Shape>(),
11007  dmlc::optional<Shape> rhs_axes = dmlc::optional<Shape>()) {
11008  return Operator("broadcast_like")
11009  .SetParam("lhs_axes", lhs_axes)
11010  .SetParam("rhs_axes", rhs_axes)
11011  .SetInput("lhs", lhs)
11012  .SetInput("rhs", rhs)
11013  .CreateSymbol();
11014 }
11015 
11058 inline Symbol norm(Symbol data,
11059  int ord = 2,
11060  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
11061  bool keepdims = false) {
11062  return Operator("norm")
11063  .SetParam("ord", ord)
11064  .SetParam("axis", axis)
11065  .SetParam("keepdims", keepdims)
11066  .SetInput("data", data)
11067  .CreateSymbol();
11068 }
11069 
11112 inline Symbol topk(Symbol data,
11113  dmlc::optional<int> axis = dmlc::optional<int>(-1),
11114  int k = 1,
11115  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
11116  bool is_ascend = false,
11117  TopkDtype dtype = TopkDtype::kFloat32) {
11118  static const char *TopkRetTypValues[] = {
11119  "both",
11120  "indices",
11121  "mask",
11122  "value"
11123  };
11124  static const char *TopkDtypeValues[] = {
11125  "float16",
11126  "float32",
11127  "float64",
11128  "int32",
11129  "uint8"
11130  };
11131  return Operator("topk")
11132  .SetParam("axis", axis)
11133  .SetParam("k", k)
11134  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
11135  .SetParam("is_ascend", is_ascend)
11136  .SetParam("dtype", TopkDtypeValues[int(dtype)])
11137  .SetInput("data", data)
11138  .CreateSymbol();
11139 }
11140 
11172 inline Symbol sort(Symbol data,
11173  dmlc::optional<int> axis = dmlc::optional<int>(-1),
11174  bool is_ascend = true) {
11175  return Operator("sort")
11176  .SetParam("axis", axis)
11177  .SetParam("is_ascend", is_ascend)
11178  .SetInput("data", data)
11179  .CreateSymbol();
11180 }
11181 
/*!
 * \brief Builds a symbol for the "argsort" operator.
 *
 * Forwards axis, is_ascend and dtype as operator parameters, binds data to
 * input "data", and creates the symbol without an explicit name.
 *
 * NOTE(review): `dtype` is read in the body and the parameter list is never
 * closed with `) {`, but no declaration of `dtype` is visible in this listing
 * (the embedded numbering skips 11216). Presumably an enum-typed trailing
 * parameter was lost when the listing was produced - confirm against the
 * original header.
 */
11213 inline Symbol argsort(Symbol data,
11214  dmlc::optional<int> axis = dmlc::optional<int>(-1),
11215  bool is_ascend = true,
// String spellings sent to the operator, indexed by the integer value of dtype.
11217  static const char *ArgsortDtypeValues[] = {
11218  "float16",
11219  "float32",
11220  "float64",
11221  "int32",
11222  "uint8"
11223  };
11224  return Operator("argsort")
11225  .SetParam("axis", axis)
11226  .SetParam("is_ascend", is_ascend)
11227  .SetParam("dtype", ArgsortDtypeValues[int(dtype)])
11228  .SetInput("data", data)
11229  .CreateSymbol();
11230 }
11231 
11251  Symbol rhs) {
11252  return Operator("elemwise_add")
11253  .SetInput("lhs", lhs)
11254  .SetInput("rhs", rhs)
11255  .CreateSymbol();
11256 }
11257 
11277  Symbol rhs) {
11278  return Operator("elemwise_sub")
11279  .SetInput("lhs", lhs)
11280  .SetInput("rhs", rhs)
11281  .CreateSymbol();
11282 }
11283 
11302  Symbol rhs) {
11303  return Operator("elemwise_mul")
11304  .SetInput("lhs", lhs)
11305  .SetInput("rhs", rhs)
11306  .CreateSymbol();
11307 }
11308 
11320  Symbol rhs) {
11321  return Operator("elemwise_div")
11322  .SetInput("lhs", lhs)
11323  .SetInput("rhs", rhs)
11324  .CreateSymbol();
11325 }
11326 
11390  Symbol weight,
11391  int input_dim,
11392  int output_dim,
11394  bool sparse_grad = false) {
11395  static const char *EmbeddingDtypeValues[] = {
11396  "float16",
11397  "float32",
11398  "float64",
11399  "int32",
11400  "int64",
11401  "int8",
11402  "uint8"
11403  };
11404  return Operator("Embedding")
11405  .SetParam("input_dim", input_dim)
11406  .SetParam("output_dim", output_dim)
11407  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
11408  .SetParam("sparse_grad", sparse_grad)
11409  .SetInput("data", data)
11410  .SetInput("weight", weight)
11411  .CreateSymbol();
11412 }
11413 
11472 inline Symbol take(Symbol a,
11473  Symbol indices,
11474  int axis = 0,
11475  TakeMode mode = TakeMode::kClip) {
11476  static const char *TakeModeValues[] = {
11477  "clip",
11478  "raise",
11479  "wrap"
11480  };
11481  return Operator("take")
11482  .SetParam("axis", axis)
11483  .SetParam("mode", TakeModeValues[int(mode)])
11484  .SetInput("a", a)
11485  .SetInput("indices", indices)
11486  .CreateSymbol();
11487 }
11488 
11517  Symbol indices) {
11518  return Operator("batch_take")
11519  .SetInput("a", a)
11520  .SetInput("indices", indices)
11521  .CreateSymbol();
11522 }
11523 
/*!
 * \brief Builds a symbol for the "one_hot" operator.
 *
 * Forwards depth, on_value, off_value and dtype as operator parameters, binds
 * indices to input "indices", and creates the symbol without an explicit name.
 *
 * NOTE(review): `dtype` is read in the body and the parameter list is never
 * closed with `) {`, but no declaration of `dtype` is visible in this listing
 * (the embedded numbering skips 11571). Presumably an enum-typed trailing
 * parameter was lost when the listing was produced - confirm against the
 * original header.
 */
11567 inline Symbol one_hot(Symbol indices,
11568  int depth,
11569  double on_value = 1,
11570  double off_value = 0,
// String spellings sent to the operator, indexed by the integer value of dtype.
11572  static const char *One_hotDtypeValues[] = {
11573  "float16",
11574  "float32",
11575  "float64",
11576  "int32",
11577  "int64",
11578  "int8",
11579  "uint8"
11580  };
11581  return Operator("one_hot")
11582  .SetParam("depth", depth)
11583  .SetParam("on_value", on_value)
11584  .SetParam("off_value", off_value)
11585  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
11586  .SetInput("indices", indices)
11587  .CreateSymbol();
11588 }
11589 
11621  Symbol indices) {
11622  return Operator("gather_nd")
11623  .SetInput("data", data)
11624  .SetInput("indices", indices)
11625  .CreateSymbol();
11626 }
11627 
11679  Symbol indices,
11680  Shape shape) {
11681  return Operator("scatter_nd")
11682  .SetParam("shape", shape)
11683  .SetInput("data", data)
11684  .SetInput("indices", indices)
11685  .CreateSymbol();
11686 }
11687 
11710  Symbol rhs) {
11711  return Operator("broadcast_equal")
11712  .SetInput("lhs", lhs)
11713  .SetInput("rhs", rhs)
11714  .CreateSymbol();
11715 }
11716 
11739  Symbol rhs) {
11740  return Operator("broadcast_not_equal")
11741  .SetInput("lhs", lhs)
11742  .SetInput("rhs", rhs)
11743  .CreateSymbol();
11744 }
11745 
11768  Symbol rhs) {
11769  return Operator("broadcast_greater")
11770  .SetInput("lhs", lhs)
11771  .SetInput("rhs", rhs)
11772  .CreateSymbol();
11773 }
11774 
11797  Symbol rhs) {
11798  return Operator("broadcast_greater_equal")
11799  .SetInput("lhs", lhs)
11800  .SetInput("rhs", rhs)
11801  .CreateSymbol();
11802 }
11803 
11826  Symbol rhs) {
11827  return Operator("broadcast_lesser")
11828  .SetInput("lhs", lhs)
11829  .SetInput("rhs", rhs)
11830  .CreateSymbol();
11831 }
11832 
11855  Symbol rhs) {
11856  return Operator("broadcast_lesser_equal")
11857  .SetInput("lhs", lhs)
11858  .SetInput("rhs", rhs)
11859  .CreateSymbol();
11860 }
11861 
11884  Symbol rhs) {
11885  return Operator("broadcast_logical_and")
11886  .SetInput("lhs", lhs)
11887  .SetInput("rhs", rhs)
11888  .CreateSymbol();
11889 }
11890 
11913  Symbol rhs) {
11914  return Operator("broadcast_logical_or")
11915  .SetInput("lhs", lhs)
11916  .SetInput("rhs", rhs)
11917  .CreateSymbol();
11918 }
11919 
11942  Symbol rhs) {
11943  return Operator("broadcast_logical_xor")
11944  .SetInput("lhs", lhs)
11945  .SetInput("rhs", rhs)
11946  .CreateSymbol();
11947 }
11948 
12012 inline Symbol diag(Symbol data,
12013  int k = 0,
12014  int axis1 = 0,
12015  int axis2 = 1) {
12016  return Operator("diag")
12017  .SetParam("k", k)
12018  .SetParam("axis1", axis1)
12019  .SetParam("axis2", axis2)
12020  .SetInput("data", data)
12021  .CreateSymbol();
12022 }
12023 
12058 inline Symbol where(Symbol condition,
12059  Symbol x,
12060  Symbol y) {
12061  return Operator("where")
12062  .SetInput("condition", condition)
12063  .SetInput("x", x)
12064  .SetInput("y", y)
12065  .CreateSymbol();
12066 }
12067 
12094  mx_float scalar) {
12095  return Operator("smooth_l1")
12096  .SetParam("scalar", scalar)
12097  .SetInput("data", data)
12098  .CreateSymbol();
12099 }
12100 
12146  Cast_storageStype stype) {
12147  static const char *Cast_storageStypeValues[] = {
12148  "csr",
12149  "default",
12150  "row_sparse"
12151  };
12152  return Operator("cast_storage")
12153  .SetParam("stype", Cast_storageStypeValues[int(stype)])
12154  .SetInput("data", data)
12155  .CreateSymbol();
12156 }
12157 
12178 inline Symbol sin(Symbol data) {
12179  return Operator("sin")
12180  .SetInput("data", data)
12181  .CreateSymbol();
12182 }
12183 
12200 inline Symbol cos(Symbol data) {
12201  return Operator("cos")
12202  .SetInput("data", data)
12203  .CreateSymbol();
12204 }
12205 
12226 inline Symbol tan(Symbol data) {
12227  return Operator("tan")
12228  .SetInput("data", data)
12229  .CreateSymbol();
12230 }
12231 
12253 inline Symbol arcsin(Symbol data) {
12254  return Operator("arcsin")
12255  .SetInput("data", data)
12256  .CreateSymbol();
12257 }
12258 
12276 inline Symbol arccos(Symbol data) {
12277  return Operator("arccos")
12278  .SetInput("data", data)
12279  .CreateSymbol();
12280 }
12281 
12302 inline Symbol arctan(Symbol data) {
12303  return Operator("arctan")
12304  .SetInput("data", data)
12305  .CreateSymbol();
12306 }
12307 
12326 inline Symbol degrees(Symbol data) {
12327  return Operator("degrees")
12328  .SetInput("data", data)
12329  .CreateSymbol();
12330 }
12331 
12350 inline Symbol radians(Symbol data) {
12351  return Operator("radians")
12352  .SetInput("data", data)
12353  .CreateSymbol();
12354 }
12355 
12374 inline Symbol sinh(Symbol data) {
12375  return Operator("sinh")
12376  .SetInput("data", data)
12377  .CreateSymbol();
12378 }
12379 
12394 inline Symbol cosh(Symbol data) {
12395  return Operator("cosh")
12396  .SetInput("data", data)
12397  .CreateSymbol();
12398 }
12399 
12418 inline Symbol tanh(Symbol data) {
12419  return Operator("tanh")
12420  .SetInput("data", data)
12421  .CreateSymbol();
12422 }
12423 
12440 inline Symbol arcsinh(Symbol data) {
12441  return Operator("arcsinh")
12442  .SetInput("data", data)
12443  .CreateSymbol();
12444 }
12445 
12458 inline Symbol arccosh(Symbol data) {
12459  return Operator("arccosh")
12460  .SetInput("data", data)
12461  .CreateSymbol();
12462 }
12463 
12480 inline Symbol arctanh(Symbol data) {
12481  return Operator("arctanh")
12482  .SetInput("data", data)
12483  .CreateSymbol();
12484 }
12485 
/*!
 * \brief Create a symbol for the "Pooling" operator.
 *
 * Forwards kernel, pool_type, global_pool, cudnn_off, pooling_convention,
 * stride, pad, p_value and count_include_pad as operator parameters and binds
 * `data` to the "data" input. pool_type and pooling_convention are converted
 * to strings by indexing the local lookup tables with their integer value.
 *
 * NOTE(review): `pool_type` and `pooling_convention` are used in the body but
 * are not declared in the parameter list as it appears here (listing lines
 * 12555 and 12558 are absent) - presumably dropped when this listing was
 * generated; confirm against the original header.
 *
 * \return Symbol returned by Operator::CreateSymbol().
 */
12553 inline Symbol Pooling(Symbol data,
12554  Shape kernel = Shape(),
12556  bool global_pool = false,
12557  bool cudnn_off = false,
12559  Shape stride = Shape(),
12560  Shape pad = Shape(),
12561  dmlc::optional<int> p_value = dmlc::optional<int>(),
12562  dmlc::optional<bool> count_include_pad = dmlc::optional<bool>()) {
12563  static const char *PoolingPoolTypeValues[] = {
12564  "avg",
12565  "lp",
12566  "max",
12567  "sum"
12568  };
12569  static const char *PoolingPoolingConventionValues[] = {
12570  "full",
12571  "same",
12572  "valid"
12573  };
12574  return Operator("Pooling")
12575  .SetParam("kernel", kernel)
12576  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
12577  .SetParam("global_pool", global_pool)
12578  .SetParam("cudnn_off", cudnn_off)
12579  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
12580  .SetParam("stride", stride)
12581  .SetParam("pad", pad)
12582  .SetParam("p_value", p_value)
12583  .SetParam("count_include_pad", count_include_pad)
12584  .SetInput("data", data)
12585  .CreateSymbol();
12586 }
12587 
12619 inline Symbol softmax(Symbol data,
12620  int axis = -1,
12621  dmlc::optional<double> temperature = dmlc::optional<double>()) {
12622  return Operator("softmax")
12623  .SetParam("axis", axis)
12624  .SetParam("temperature", temperature)
12625  .SetInput("data", data)
12626  .CreateSymbol();
12627 }
12628 
12661 inline Symbol softmin(Symbol data,
12662  int axis = -1,
12663  dmlc::optional<double> temperature = dmlc::optional<double>()) {
12664  return Operator("softmin")
12665  .SetParam("axis", axis)
12666  .SetParam("temperature", temperature)
12667  .SetInput("data", data)
12668  .CreateSymbol();
12669 }
12670 
12694  int axis = -1,
12695  dmlc::optional<double> temperature = dmlc::optional<double>()) {
12696  return Operator("log_softmax")
12697  .SetParam("axis", axis)
12698  .SetParam("temperature", temperature)
12699  .SetInput("data", data)
12700  .CreateSymbol();
12701 }
12702 
12732  Symbol weight,
12733  Symbol bias,
12734  Shape kernel,
12735  uint32_t num_filter,
12736  Shape stride = Shape(),
12737  Shape dilate = Shape(),
12738  Shape pad = Shape(),
12739  Shape adj = Shape(),
12740  Shape target_shape = Shape(),
12741  uint32_t num_group = 1,
12742  uint64_t workspace = 512,
12743  bool no_bias = true,
12745  bool cudnn_off = false,
12747  static const char *DeconvolutionCudnnTuneValues[] = {
12748  "None",
12749  "fastest",
12750  "limited_workspace",
12751  "off"
12752  };
12753  static const char *DeconvolutionLayoutValues[] = {
12754  "None",
12755  "NCDHW",
12756  "NCHW",
12757  "NCW",
12758  "NDHWC",
12759  "NHWC"
12760  };
12761  return Operator("Deconvolution")
12762  .SetParam("kernel", kernel)
12763  .SetParam("num_filter", num_filter)
12764  .SetParam("stride", stride)
12765  .SetParam("dilate", dilate)
12766  .SetParam("pad", pad)
12767  .SetParam("adj", adj)
12768  .SetParam("target_shape", target_shape)
12769  .SetParam("num_group", num_group)
12770  .SetParam("workspace", workspace)
12771  .SetParam("no_bias", no_bias)
12772  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
12773  .SetParam("cudnn_off", cudnn_off)
12774  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
12775  .SetInput("data", data)
12776  .SetInput("weight", weight)
12777  .SetInput("bias", bias)
12778  .CreateSymbol();
12779 }
12780 
12800  ActivationActType act_type) {
12801  static const char *ActivationActTypeValues[] = {
12802  "relu",
12803  "sigmoid",
12804  "softrelu",
12805  "softsign",
12806  "tanh"
12807  };
12808  return Operator("Activation")
12809  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
12810  .SetInput("data", data)
12811  .CreateSymbol();
12812 }
12813 
12882  Symbol gamma,
12883  Symbol beta,
12884  Symbol moving_mean,
12885  Symbol moving_var,
12886  double eps = 0.001,
12887  mx_float momentum = 0.9,
12888  bool fix_gamma = true,
12889  bool use_global_stats = false,
12890  bool output_mean_var = false,
12891  int axis = 1,
12892  bool cudnn_off = false) {
12893  return Operator("BatchNorm")
12894  .SetParam("eps", eps)
12895  .SetParam("momentum", momentum)
12896  .SetParam("fix_gamma", fix_gamma)
12897  .SetParam("use_global_stats", use_global_stats)
12898  .SetParam("output_mean_var", output_mean_var)
12899  .SetParam("axis", axis)
12900  .SetParam("cudnn_off", cudnn_off)
12901  .SetInput("data", data)
12902  .SetInput("gamma", gamma)
12903  .SetInput("beta", beta)
12904  .SetInput("moving_mean", moving_mean)
12905  .SetInput("moving_var", moving_var)
12906  .CreateSymbol();
12907 }
12908 
/*!
 * \brief Create a symbol for the "CTCLoss" operator.
 *
 * Forwards use_data_lengths, use_label_lengths and blank_label as operator
 * parameters and binds data, label, data_lengths and label_lengths to the
 * inputs of the same names. blank_label is converted to a string by indexing
 * the local lookup table with its integer value.
 *
 * NOTE(review): `blank_label` is used in the body but its declaration, and the
 * end of the parameter list, do not appear here (listing line 12980 is
 * absent) - presumably dropped when this listing was generated; confirm
 * against the original header.
 *
 * \return Symbol returned by Operator::CreateSymbol().
 */
12974 inline Symbol CTCLoss(Symbol data,
12975  Symbol label,
12976  Symbol data_lengths,
12977  Symbol label_lengths,
12978  bool use_data_lengths = false,
12979  bool use_label_lengths = false,
12981  static const char *CTCLossBlankLabelValues[] = {
12982  "first",
12983  "last"
12984  };
12985  return Operator("CTCLoss")
12986  .SetParam("use_data_lengths", use_data_lengths)
12987  .SetParam("use_label_lengths", use_label_lengths)
12988  .SetParam("blank_label", CTCLossBlankLabelValues[int(blank_label)])
12989  .SetInput("data", data)
12990  .SetInput("label", label)
12991  .SetInput("data_lengths", data_lengths)
12992  .SetInput("label_lengths", label_lengths)
12993  .CreateSymbol();
12994 }
12995 
13093  Symbol weight,
13094  Symbol bias,
13095  Shape kernel,
13096  uint32_t num_filter,
13097  Shape stride = Shape(),
13098  Shape dilate = Shape(),
13099  Shape pad = Shape(),
13100  uint32_t num_group = 1,
13101  uint64_t workspace = 1024,
13102  bool no_bias = false,
13104  bool cudnn_off = false,
13106  static const char *ConvolutionCudnnTuneValues[] = {
13107  "None",
13108  "fastest",
13109  "limited_workspace",
13110  "off"
13111  };
13112  static const char *ConvolutionLayoutValues[] = {
13113  "None",
13114  "NCDHW",
13115  "NCHW",
13116  "NCW",
13117  "NDHWC",
13118  "NHWC"
13119  };
13120  return Operator("Convolution")
13121  .SetParam("kernel", kernel)
13122  .SetParam("num_filter", num_filter)
13123  .SetParam("stride", stride)
13124  .SetParam("dilate", dilate)
13125  .SetParam("pad", pad)
13126  .SetParam("num_group", num_group)
13127  .SetParam("workspace", workspace)
13128  .SetParam("no_bias", no_bias)
13129  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
13130  .SetParam("cudnn_off", cudnn_off)
13131  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
13132  .SetInput("data", data)
13133  .SetInput("weight", weight)
13134  .SetInput("bias", bias)
13135  .CreateSymbol();
13136 }
13137 
/*!
 * \brief Create a symbol for the "UpSampling" operator.
 *
 * Forwards scale, sample_type, num_args, num_filter, multi_input_mode and
 * workspace as operator parameters, then passes the `data` vector to the
 * operator through its call operator. sample_type and multi_input_mode are
 * converted to strings by indexing the local lookup tables with their integer
 * value.
 *
 * NOTE(review): `multi_input_mode` is used in the body but is not declared in
 * the parameter list as it appears here (listing line 13157 is absent) -
 * presumably dropped when this listing was generated; confirm against the
 * original header.
 *
 * \return Symbol returned by Operator::CreateSymbol().
 */
13152 inline Symbol UpSampling(const std::vector<Symbol>& data,
13153  int scale,
13154  UpSamplingSampleType sample_type,
13155  int num_args,
13156  int num_filter = 0,
13158  uint64_t workspace = 512) {
13159  static const char *UpSamplingSampleTypeValues[] = {
13160  "bilinear",
13161  "nearest"
13162  };
13163  static const char *UpSamplingMultiInputModeValues[] = {
13164  "concat",
13165  "sum"
13166  };
13167  return Operator("UpSampling")
13168  .SetParam("scale", scale)
13169  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
13170  .SetParam("num_args", num_args)
13171  .SetParam("num_filter", num_filter)
13172  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
13173  .SetParam("workspace", workspace)
13174 (data)
13175  .CreateSymbol();
13176 }
13177 
13223 inline Symbol Concat(const std::vector<Symbol>& data,
13224  int num_args,
13225  int dim = 1) {
13226  return Operator("Concat")
13227  .SetParam("num_args", num_args)
13228  .SetParam("dim", dim)
13229 (data)
13230  .CreateSymbol();
13231 }
13232 
13271  Symbol gamma,
13272  Symbol beta,
13273  int axis = -1,
13274  mx_float eps = 1e-05,
13275  bool output_mean_var = false) {
13276  return Operator("LayerNorm")
13277  .SetParam("axis", axis)
13278  .SetParam("eps", eps)
13279  .SetParam("output_mean_var", output_mean_var)
13280  .SetInput("data", data)
13281  .SetInput("gamma", gamma)
13282  .SetInput("beta", beta)
13283  .CreateSymbol();
13284 }
13285 
13312 inline Symbol LRN(Symbol data,
13313  uint32_t nsize,
13314  mx_float alpha = 0.0001,
13315  mx_float beta = 0.75,
13316  mx_float knorm = 2) {
13317  return Operator("LRN")
13318  .SetParam("nsize", nsize)
13319  .SetParam("alpha", alpha)
13320  .SetParam("beta", beta)
13321  .SetParam("knorm", knorm)
13322  .SetInput("data", data)
13323  .CreateSymbol();
13324 }
13325 
/*!
 * \brief Create a symbol for the "Dropout" operator.
 *
 * Forwards p, mode and axes as operator parameters and binds `data` to the
 * "data" input. mode is converted to a string by indexing the local lookup
 * table with its integer value.
 *
 * NOTE(review): `mode` is used in the body but is not declared in the
 * parameter list as it appears here (listing line 13367 is absent) -
 * presumably dropped when this listing was generated; confirm against the
 * original header.
 *
 * \return Symbol returned by Operator::CreateSymbol().
 */
13365 inline Symbol Dropout(Symbol data,
13366  mx_float p = 0.5,
13368  Shape axes = Shape()) {
13369  static const char *DropoutModeValues[] = {
13370  "always",
13371  "training"
13372  };
13373  return Operator("Dropout")
13374  .SetParam("p", p)
13375  .SetParam("mode", DropoutModeValues[int(mode)])
13376  .SetParam("axes", axes)
13377  .SetInput("data", data)
13378  .CreateSymbol();
13379 }
13380 
13415  static const char *SoftmaxActivationModeValues[] = {
13416  "channel",
13417  "instance"
13418  };
13419  return Operator("SoftmaxActivation")
13420  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
13421  .SetInput("data", data)
13422  .CreateSymbol();
13423 }
13424 
13468  Symbol weight,
13469  Symbol bias,
13470  int num_hidden,
13471  bool no_bias = false,
13472  bool flatten = true) {
13473  return Operator("FullyConnected")
13474  .SetParam("num_hidden", num_hidden)
13475  .SetParam("no_bias", no_bias)
13476  .SetParam("flatten", flatten)
13477  .SetInput("data", data)
13478  .SetInput("weight", weight)
13479  .SetInput("bias", bias)
13480  .CreateSymbol();
13481 }
13482 
13578 inline Symbol Pad(Symbol data,
13579  PadMode mode,
13580  Shape pad_width,
13581  double constant_value = 0) {
13582  static const char *PadModeValues[] = {
13583  "constant",
13584  "edge",
13585  "reflect"
13586  };
13587  return Operator("Pad")
13588  .SetParam("mode", PadModeValues[int(mode)])
13589  .SetParam("pad_width", pad_width)
13590  .SetParam("constant_value", constant_value)
13591  .SetInput("data", data)
13592  .CreateSymbol();
13593 }
13594 
13625  Symbol gamma,
13627  mx_float slope = 0.25,
13628  mx_float lower_bound = 0.125,
13629  mx_float upper_bound = 0.334) {
13630  static const char *LeakyReLUActTypeValues[] = {
13631  "elu",
13632  "leaky",
13633  "prelu",
13634  "rrelu",
13635  "selu"
13636  };
13637  return Operator("LeakyReLU")
13638  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
13639  .SetParam("slope", slope)
13640  .SetParam("lower_bound", lower_bound)
13641  .SetParam("upper_bound", upper_bound)
13642  .SetInput("data", data)
13643  .SetInput("gamma", gamma)
13644  .CreateSymbol();
13645 }
13646 
13674 inline Symbol SwapAxis(Symbol data,
13675  uint32_t dim1 = 0,
13676  uint32_t dim2 = 0) {
13677  return Operator("SwapAxis")
13678  .SetParam("dim1", dim1)
13679  .SetParam("dim2", dim2)
13680  .SetInput("data", data)
13681  .CreateSymbol();
13682 }
13683 
13744  Symbol gamma,
13745  Symbol beta,
13746  mx_float eps = 0.001,
13747  mx_float momentum = 0.9,
13748  bool fix_gamma = true,
13749  bool use_global_stats = false,
13750  bool output_mean_var = false) {
13751  return Operator("BatchNorm_v1")
13752  .SetParam("eps", eps)
13753  .SetParam("momentum", momentum)
13754  .SetParam("fix_gamma", fix_gamma)
13755  .SetParam("use_global_stats", use_global_stats)
13756  .SetParam("output_mean_var", output_mean_var)
13757  .SetInput("data", data)
13758  .SetInput("gamma", gamma)
13759  .SetInput("beta", beta)
13760  .CreateSymbol();
13761 }
13762 
13800  Symbol label) {
13801  return Operator("softmax_cross_entropy")
13802  .SetInput("data", data)
13803  .SetInput("label", label)
13804  .CreateSymbol();
13805 }
13806 
13836  Symbol label,
13837  mx_float grad_scale = 1) {
13838  return Operator("LinearRegressionOutput")
13839  .SetParam("grad_scale", grad_scale)
13840  .SetInput("data", data)
13841  .SetInput("label", label)
13842  .CreateSymbol();
13843 }
13844 
13875  Symbol label,
13876  mx_float grad_scale = 1) {
13877  return Operator("MAERegressionOutput")
13878  .SetParam("grad_scale", grad_scale)
13879  .SetInput("data", data)
13880  .SetInput("label", label)
13881  .CreateSymbol();
13882 }
13883 
13920  Symbol label,
13921  mx_float grad_scale = 1) {
13922  return Operator("LogisticRegressionOutput")
13923  .SetParam("grad_scale", grad_scale)
13924  .SetInput("data", data)
13925  .SetInput("label", label)
13926  .CreateSymbol();
13927 }
13928 
13938  mx_float sparseness_target = 0.1,
13939  mx_float penalty = 0.001,
13940  mx_float momentum = 0.9) {
13941  return Operator("IdentityAttachKLSparseReg")
13942  .SetParam("sparseness_target", sparseness_target)
13943  .SetParam("penalty", penalty)
13944  .SetParam("momentum", momentum)
13945  .SetInput("data", data)
13946  .CreateSymbol();
13947 }
13948 
13977  Symbol grad,
13978  mx_float lr,
13979  mx_float wd = 0,
13980  mx_float rescale_grad = 1,
13981  mx_float clip_gradient = -1) {
13982  return Operator("signsgd_update")
13983  .SetParam("lr", lr)
13984  .SetParam("wd", wd)
13985  .SetParam("rescale_grad", rescale_grad)
13986  .SetParam("clip_gradient", clip_gradient)
13987  .SetInput("weight", weight)
13988  .SetInput("grad", grad)
13989  .CreateSymbol();
13990 }
13991 
14026  Symbol grad,
14027  Symbol mom,
14028  mx_float lr,
14029  mx_float momentum = 0,
14030  mx_float wd = 0,
14031  mx_float rescale_grad = 1,
14032  mx_float clip_gradient = -1,
14033  mx_float wd_lh = 0) {
14034  return Operator("signum_update")
14035  .SetParam("lr", lr)
14036  .SetParam("momentum", momentum)
14037  .SetParam("wd", wd)
14038  .SetParam("rescale_grad", rescale_grad)
14039  .SetParam("clip_gradient", clip_gradient)
14040  .SetParam("wd_lh", wd_lh)
14041  .SetInput("weight", weight)
14042  .SetInput("grad", grad)
14043  .SetInput("mom", mom)
14044  .CreateSymbol();
14045 }
14046 
14074 inline Symbol sgd_update(Symbol weight,
14075  Symbol grad,
14076  mx_float lr,
14077  mx_float wd = 0,
14078  mx_float rescale_grad = 1,
14079  mx_float clip_gradient = -1,
14080  bool lazy_update = true) {
14081  return Operator("sgd_update")
14082  .SetParam("lr", lr)
14083  .SetParam("wd", wd)
14084  .SetParam("rescale_grad", rescale_grad)
14085  .SetParam("clip_gradient", clip_gradient)
14086  .SetParam("lazy_update", lazy_update)
14087  .SetInput("weight", weight)
14088  .SetInput("grad", grad)
14089  .CreateSymbol();
14090 }
14091 
14136  Symbol grad,
14137  Symbol mom,
14138  mx_float lr,
14139  mx_float momentum = 0,
14140  mx_float wd = 0,
14141  mx_float rescale_grad = 1,
14142  mx_float clip_gradient = -1,
14143  bool lazy_update = true) {
14144  return Operator("sgd_mom_update")
14145  .SetParam("lr", lr)
14146  .SetParam("momentum", momentum)
14147  .SetParam("wd", wd)
14148  .SetParam("rescale_grad", rescale_grad)
14149  .SetParam("clip_gradient", clip_gradient)
14150  .SetParam("lazy_update", lazy_update)
14151  .SetInput("weight", weight)
14152  .SetInput("grad", grad)
14153  .SetInput("mom", mom)
14154  .CreateSymbol();
14155 }
14156 
14172  Symbol grad,
14173  Symbol weight32,
14174  mx_float lr,
14175  mx_float wd = 0,
14176  mx_float rescale_grad = 1,
14177  mx_float clip_gradient = -1,
14178  bool lazy_update = true) {
14179  return Operator("mp_sgd_update")
14180  .SetParam("lr", lr)
14181  .SetParam("wd", wd)
14182  .SetParam("rescale_grad", rescale_grad)
14183  .SetParam("clip_gradient", clip_gradient)
14184  .SetParam("lazy_update", lazy_update)
14185  .SetInput("weight", weight)
14186  .SetInput("grad", grad)
14187  .SetInput("weight32", weight32)
14188  .CreateSymbol();
14189 }
14190 
14208  Symbol grad,
14209  Symbol mom,
14210  Symbol weight32,
14211  mx_float lr,
14212  mx_float momentum = 0,
14213  mx_float wd = 0,
14214  mx_float rescale_grad = 1,
14215  mx_float clip_gradient = -1,
14216  bool lazy_update = true) {
14217  return Operator("mp_sgd_mom_update")
14218  .SetParam("lr", lr)
14219  .SetParam("momentum", momentum)
14220  .SetParam("wd", wd)
14221  .SetParam("rescale_grad", rescale_grad)
14222  .SetParam("clip_gradient", clip_gradient)
14223  .SetParam("lazy_update", lazy_update)
14224  .SetInput("weight", weight)
14225  .SetInput("grad", grad)
14226  .SetInput("mom", mom)
14227  .SetInput("weight32", weight32)
14228  .CreateSymbol();
14229 }
14230 
14265 inline Symbol ftml_update(Symbol weight,
14266  Symbol grad,
14267  Symbol d,
14268  Symbol v,
14269  Symbol z,
14270  mx_float lr,
14271  int t,
14272  mx_float beta1 = 0.6,
14273  mx_float beta2 = 0.999,
14274  double epsilon = 1e-08,
14275  mx_float wd = 0,
14276  mx_float rescale_grad = 1,
14277  mx_float clip_grad = -1) {
14278  return Operator("ftml_update")
14279  .SetParam("lr", lr)
14280  .SetParam("t", t)
14281  .SetParam("beta1", beta1)
14282  .SetParam("beta2", beta2)
14283  .SetParam("epsilon", epsilon)
14284  .SetParam("wd", wd)
14285  .SetParam("rescale_grad", rescale_grad)
14286  .SetParam("clip_grad", clip_grad)
14287  .SetInput("weight", weight)
14288  .SetInput("grad", grad)
14289  .SetInput("d", d)
14290  .SetInput("v", v)
14291  .SetInput("z", z)
14292  .CreateSymbol();
14293 }
14294 
14343 inline Symbol adam_update(Symbol weight,
14344  Symbol grad,
14345  Symbol mean,
14346  Symbol var,
14347  mx_float lr,
14348  mx_float beta1 = 0.9,
14349  mx_float beta2 = 0.999,
14350  mx_float epsilon = 1e-08,
14351  mx_float wd = 0,
14352  mx_float rescale_grad = 1,
14353  mx_float clip_gradient = -1,
14354  bool lazy_update = true) {
14355  return Operator("adam_update")
14356  .SetParam("lr", lr)
14357  .SetParam("beta1", beta1)
14358  .SetParam("beta2", beta2)
14359  .SetParam("epsilon", epsilon)
14360  .SetParam("wd", wd)
14361  .SetParam("rescale_grad", rescale_grad)
14362  .SetParam("clip_gradient", clip_gradient)
14363  .SetParam("lazy_update", lazy_update)
14364  .SetInput("weight", weight)
14365  .SetInput("grad", grad)
14366  .SetInput("mean", mean)
14367  .SetInput("var", var)
14368  .CreateSymbol();
14369 }
14370 
14424  Symbol grad,
14425  Symbol n,
14426  mx_float lr,
14427  mx_float gamma1 = 0.95,
14428  mx_float epsilon = 1e-08,
14429  mx_float wd = 0,
14430  mx_float rescale_grad = 1,
14431  mx_float clip_gradient = -1,
14432  mx_float clip_weights = -1) {
14433  return Operator("rmsprop_update")
14434  .SetParam("lr", lr)
14435  .SetParam("gamma1", gamma1)
14436  .SetParam("epsilon", epsilon)
14437  .SetParam("wd", wd)
14438  .SetParam("rescale_grad", rescale_grad)
14439  .SetParam("clip_gradient", clip_gradient)
14440  .SetParam("clip_weights", clip_weights)
14441  .SetInput("weight", weight)
14442  .SetInput("grad", grad)
14443  .SetInput("n", n)
14444  .CreateSymbol();
14445 }
14446 
14492  Symbol grad,
14493  Symbol n,
14494  Symbol g,
14495  Symbol delta,
14496  mx_float lr,
14497  mx_float gamma1 = 0.95,
14498  mx_float gamma2 = 0.9,
14499  mx_float epsilon = 1e-08,
14500  mx_float wd = 0,
14501  mx_float rescale_grad = 1,
14502  mx_float clip_gradient = -1,
14503  mx_float clip_weights = -1) {
14504  return Operator("rmspropalex_update")
14505  .SetParam("lr", lr)
14506  .SetParam("gamma1", gamma1)
14507  .SetParam("gamma2", gamma2)
14508  .SetParam("epsilon", epsilon)
14509  .SetParam("wd", wd)
14510  .SetParam("rescale_grad", rescale_grad)
14511  .SetParam("clip_gradient", clip_gradient)
14512  .SetParam("clip_weights", clip_weights)
14513  .SetInput("weight", weight)
14514  .SetInput("grad", grad)
14515  .SetInput("n", n)
14516  .SetInput("g", g)
14517  .SetInput("delta", delta)
14518  .CreateSymbol();
14519 }
14520 
14559 inline Symbol ftrl_update(Symbol weight,
14560  Symbol grad,
14561  Symbol z,
14562  Symbol n,
14563  mx_float lr,
14564  mx_float lamda1 = 0.01,
14565  mx_float beta = 1,
14566  mx_float wd = 0,
14567  mx_float rescale_grad = 1,
14568  mx_float clip_gradient = -1) {
14569  return Operator("ftrl_update")
14570  .SetParam("lr", lr)
14571  .SetParam("lamda1", lamda1)
14572  .SetParam("beta", beta)
14573  .SetParam("wd", wd)
14574  .SetParam("rescale_grad", rescale_grad)
14575  .SetParam("clip_gradient", clip_gradient)
14576  .SetInput("weight", weight)
14577  .SetInput("grad", grad)
14578  .SetInput("z", z)
14579  .SetInput("n", n)
14580  .CreateSymbol();
14581 }
14582 
14654  int num_outputs,
14655  int axis = 1,
14656  bool squeeze_axis = false) {
14657  return Operator("SliceChannel")
14658  .SetParam("num_outputs", num_outputs)
14659  .SetParam("axis", axis)
14660  .SetParam("squeeze_axis", squeeze_axis)
14661  .SetInput("data", data)
14662  .CreateSymbol();
14663 }
14664 
14715  Symbol gamma,
14716  Symbol beta,
14717  mx_float eps = 0.001) {
14718  return Operator("InstanceNorm")
14719  .SetParam("eps", eps)
14720  .SetInput("data", data)
14721  .SetInput("gamma", gamma)
14722  .SetInput("beta", beta)
14723  .CreateSymbol();
14724 }
14725 
14736  GridGeneratorTransformType transform_type,
14737  Shape target_shape = Shape(0,0)) {
14738  static const char *GridGeneratorTransformTypeValues[] = {
14739  "affine",
14740  "warp"
14741  };
14742  return Operator("GridGenerator")
14743  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
14744  .SetParam("target_shape", target_shape)
14745  .SetInput("data", data)
14746  .CreateSymbol();
14747 }
14748 
14800  Shape kernel = Shape(),
14802  bool global_pool = false,
14804  Shape stride = Shape(),
14805  Shape pad = Shape()) {
14806  static const char *Pooling_v1PoolTypeValues[] = {
14807  "avg",
14808  "max",
14809  "sum"
14810  };
14811  static const char *Pooling_v1PoolingConventionValues[] = {
14812  "full",
14813  "valid"
14814  };
14815  return Operator("Pooling_v1")
14816  .SetParam("kernel", kernel)
14817  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
14818  .SetParam("global_pool", global_pool)
14819  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
14820  .SetParam("stride", stride)
14821  .SetParam("pad", pad)
14822  .SetInput("data", data)
14823  .CreateSymbol();
14824 }
14825 
14896 inline Symbol RNN(Symbol data,
14897  Symbol parameters,
14898  Symbol state,
14899  Symbol state_cell,
14900  uint32_t state_size,
14901  uint32_t num_layers,
14902  RNNMode mode,
14903  bool bidirectional = false,
14904  mx_float p = 0,
14905  bool state_outputs = false,
14906  dmlc::optional<int> projection_size = dmlc::optional<int>(),
14907  dmlc::optional<double> lstm_state_clip_min = dmlc::optional<double>(),
14908  dmlc::optional<double> lstm_state_clip_max = dmlc::optional<double>(),
14909  bool lstm_state_clip_nan = false) {
14910  static const char *RNNModeValues[] = {
14911  "gru",
14912  "lstm",
14913  "rnn_relu",
14914  "rnn_tanh"
14915  };
14916  return Operator("RNN")
14917  .SetParam("state_size", state_size)
14918  .SetParam("num_layers", num_layers)
14919  .SetParam("mode", RNNModeValues[int(mode)])
14920  .SetParam("bidirectional", bidirectional)
14921  .SetParam("p", p)
14922  .SetParam("state_outputs", state_outputs)
14923  .SetParam("projection_size", projection_size)
14924  .SetParam("lstm_state_clip_min", lstm_state_clip_min)
14925  .SetParam("lstm_state_clip_max", lstm_state_clip_max)
14926  .SetParam("lstm_state_clip_nan", lstm_state_clip_nan)
14927  .SetInput("data", data)
14928  .SetInput("parameters", parameters)
14929  .SetInput("state", state)
14930  .SetInput("state_cell", state_cell)
14931  .CreateSymbol();
14932 }
14933 
14964  Symbol weight,
14965  Symbol bias,
14966  Shape kernel,
14967  uint32_t num_filter,
14968  Shape stride = Shape(),
14969  Shape dilate = Shape(),
14970  Shape pad = Shape(),
14971  uint32_t num_group = 1,
14972  uint64_t workspace = 1024,
14973  bool no_bias = false,
14975  bool cudnn_off = false,
14977  static const char *Convolution_v1CudnnTuneValues[] = {
14978  "None",
14979  "fastest",
14980  "limited_workspace",
14981  "off"
14982  };
14983  static const char *Convolution_v1LayoutValues[] = {
14984  "None",
14985  "NCDHW",
14986  "NCHW",
14987  "NDHWC",
14988  "NHWC"
14989  };
14990  return Operator("Convolution_v1")
14991  .SetParam("kernel", kernel)
14992  .SetParam("num_filter", num_filter)
14993  .SetParam("stride", stride)
14994  .SetParam("dilate", dilate)
14995  .SetParam("pad", pad)
14996  .SetParam("num_group", num_group)
14997  .SetParam("workspace", workspace)
14998  .SetParam("no_bias", no_bias)
14999  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
15000  .SetParam("cudnn_off", cudnn_off)
15001  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
15002  .SetInput("data", data)
15003  .SetInput("weight", weight)
15004  .SetInput("bias", bias)
15005  .CreateSymbol();
15006 }
15007 
15027 inline Symbol Crop(const std::vector<Symbol>& data,
15028  int num_args,
15029  Shape offset = Shape(0,0),
15030  Shape h_w = Shape(0,0),
15031  bool center_crop = false) {
15032  return Operator("Crop")
15033  .SetParam("num_args", num_args)
15034  .SetParam("offset", offset)
15035  .SetParam("h_w", h_w)
15036  .SetParam("center_crop", center_crop)
15037 (data)
15038  .CreateSymbol();
15039 }
15040 
15117  Symbol sequence_length,
15118  bool use_sequence_length = false,
15119  int axis = 0) {
15120  return Operator("SequenceReverse")
15121  .SetParam("use_sequence_length", use_sequence_length)
15122  .SetParam("axis", axis)
15123  .SetInput("data", data)
15124  .SetInput("sequence_length", sequence_length)
15125  .CreateSymbol();
15126 }
15127 
15139  Symbol loc,
15140  SpatialTransformerTransformType transform_type,
15141  SpatialTransformerSamplerType sampler_type,
15142  Shape target_shape = Shape(0,0),
15143  dmlc::optional<bool> cudnn_off = dmlc::optional<bool>()) {
15144  static const char *SpatialTransformerTransformTypeValues[] = {
15145  "affine"
15146  };
15147  static const char *SpatialTransformerSamplerTypeValues[] = {
15148  "bilinear"
15149  };
15150  return Operator("SpatialTransformer")
15151  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
15152  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
15153  .SetParam("target_shape", target_shape)
15154  .SetParam("cudnn_off", cudnn_off)
15155  .SetInput("data", data)
15156  .SetInput("loc", loc)
15157  .CreateSymbol();
15158 }
15159 
15255  Symbol label,
15256  mx_float grad_scale = 1,
15257  mx_float ignore_label = -1,
15258  bool multi_output = false,
15259  bool use_ignore = false,
15260  bool preserve_shape = false,
15262  bool out_grad = false,
15263  mx_float smooth_alpha = 0) {
15264  static const char *SoftmaxOutputNormalizationValues[] = {
15265  "batch",
15266  "null",
15267  "valid"
15268  };
15269  return Operator("SoftmaxOutput")
15270  .SetParam("grad_scale", grad_scale)
15271  .SetParam("ignore_label", ignore_label)
15272  .SetParam("multi_output", multi_output)
15273  .SetParam("use_ignore", use_ignore)
15274  .SetParam("preserve_shape", preserve_shape)
15275  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
15276  .SetParam("out_grad", out_grad)
15277  .SetParam("smooth_alpha", smooth_alpha)
15278  .SetInput("data", data)
15279  .SetInput("label", label)
15280  .CreateSymbol();
15281 }
15282 
/*!
 * \brief Create a symbol for the "Softmax" operator.
 *
 * Forwards grad_scale, ignore_label, multi_output, use_ignore,
 * preserve_shape, normalization, out_grad and smooth_alpha as operator
 * parameters and binds `data` to the "data" input. normalization is converted
 * to a string by indexing the local lookup table with its integer value.
 *
 * NOTE(review): `normalization` is used in the body but is not declared in
 * the parameter list as it appears here (listing line 15315 is absent) -
 * presumably dropped when this listing was generated; confirm against the
 * original header.
 *
 * \return Symbol returned by Operator::CreateSymbol().
 */
15309 inline Symbol Softmax(Symbol data,
15310  mx_float grad_scale = 1,
15311  mx_float ignore_label = -1,
15312  bool multi_output = false,
15313  bool use_ignore = false,
15314  bool preserve_shape = false,
15316  bool out_grad = false,
15317  mx_float smooth_alpha = 0) {
15318  static const char *SoftmaxNormalizationValues[] = {
15319  "batch",
15320  "null",
15321  "valid"
15322  };
15323  return Operator("Softmax")
15324  .SetParam("grad_scale", grad_scale)
15325  .SetParam("ignore_label", ignore_label)
15326  .SetParam("multi_output", multi_output)
15327  .SetParam("use_ignore", use_ignore)
15328  .SetParam("preserve_shape", preserve_shape)
15329  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
15330  .SetParam("out_grad", out_grad)
15331  .SetParam("smooth_alpha", smooth_alpha)
15332  .SetInput("data", data)
15333  .CreateSymbol();
15334 }
15335 
15417  Symbol grid,
15418  dmlc::optional<bool> cudnn_off = dmlc::optional<bool>()) {
15419  return Operator("BilinearSampler")
15420  .SetParam("cudnn_off", cudnn_off)
15421  .SetInput("data", data)
15422  .SetInput("grid", grid)
15423  .CreateSymbol();
15424 }
15425 
15482  Symbol rois,
15483  Shape pooled_size,
15484  mx_float spatial_scale) {
15485  return Operator("ROIPooling")
15486  .SetParam("pooled_size", pooled_size)
15487  .SetParam("spatial_scale", spatial_scale)
15488  .SetInput("data", data)
15489  .SetInput("rois", rois)
15490  .CreateSymbol();
15491 }
15492 
15548  Symbol sequence_length,
15549  bool use_sequence_length = false,
15550  int axis = 0) {
15551  return Operator("SequenceLast")
15552  .SetParam("use_sequence_length", use_sequence_length)
15553  .SetParam("axis", axis)
15554  .SetInput("data", data)
15555  .SetInput("sequence_length", sequence_length)
15556  .CreateSymbol();
15557 }
15558 
15621  mx_float eps = 1e-10,
15623  static const char *L2NormalizationModeValues[] = {
15624  "channel",
15625  "instance",
15626  "spatial"
15627  };
15628  return Operator("L2Normalization")
15629  .SetParam("eps", eps)
15630  .SetParam("mode", L2NormalizationModeValues[int(mode)])
15631  .SetInput("data", data)
15632  .CreateSymbol();
15633 }
15634 
/*!
 * \brief Create a symbol for the "MakeLoss" operator.
 *
 * Forwards grad_scale, valid_thresh and normalization as operator parameters
 * and binds `data` to the "data" input. normalization is converted to a
 * string by indexing the local lookup table with its integer value.
 *
 * NOTE(review): `normalization` is used in the body but its declaration, and
 * the end of the parameter list, do not appear here (listing line 15671 is
 * absent) - presumably dropped when this listing was generated; confirm
 * against the original header.
 *
 * \return Symbol returned by Operator::CreateSymbol().
 */
15668 inline Symbol MakeLoss(Symbol data,
15669  mx_float grad_scale = 1,
15670  mx_float valid_thresh = 0,
15672  static const char *MakeLossNormalizationValues[] = {
15673  "batch",
15674  "null",
15675  "valid"
15676  };
15677  return Operator("MakeLoss")
15678  .SetParam("grad_scale", grad_scale)
15679  .SetParam("valid_thresh", valid_thresh)
15680  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
15681  .SetInput("data", data)
15682  .CreateSymbol();
15683 }
15684 
// Tail of a function whose opening line is absent from this listing (the
// embedded numbering resumes at 15700). Visible parameters: `label`, `margin`
// (default 1), `regularization_coefficient` (default 1) and `use_linear`
// (default false).
// NOTE(review): `data`, used below, is presumably declared on the missing
// opening line — confirm against the full header.
15700  Symbol label,
15701  mx_float margin = 1,
15702  mx_float regularization_coefficient = 1,
15703  bool use_linear = false) {
// Configure an "SVMOutput" operator: forward the three scalar settings as
// parameters, bind `data` and `label` as inputs, and create the symbol without
// an explicit name.
15704  return Operator("SVMOutput")
15705  .SetParam("margin", margin)
15706  .SetParam("regularization_coefficient", regularization_coefficient)
15707  .SetParam("use_linear", use_linear)
15708  .SetInput("data", data)
15709  .SetInput("label", label)
15710  .CreateSymbol();
15711 }
15712 
// Tail of a function whose opening line is absent from this listing (the
// embedded numbering resumes at 15762). Visible parameters: `data2` plus five
// uint32_t settings (`kernel_size`, `max_displacement`, `stride1` and `stride2`
// default to 1, `pad_size` to 0) and `is_multiply` (default true).
// NOTE(review): `data1`, used below, is presumably declared on the missing
// opening line — confirm against the full header.
15762  Symbol data2,
15763  uint32_t kernel_size = 1,
15764  uint32_t max_displacement = 1,
15765  uint32_t stride1 = 1,
15766  uint32_t stride2 = 1,
15767  uint32_t pad_size = 0,
15768  bool is_multiply = true) {
// Configure a "Correlation" operator: forward every setting under a parameter
// key equal to its argument name, bind `data1` and `data2` as inputs, and
// create the symbol without an explicit name.
15769  return Operator("Correlation")
15770  .SetParam("kernel_size", kernel_size)
15771  .SetParam("max_displacement", max_displacement)
15772  .SetParam("stride1", stride1)
15773  .SetParam("stride2", stride2)
15774  .SetParam("pad_size", pad_size)
15775  .SetParam("is_multiply", is_multiply)
15776  .SetInput("data1", data1)
15777  .SetInput("data2", data2)
15778  .CreateSymbol();
15779 }
15780 
// Tail of a function whose opening line is absent from this listing (the
// embedded numbering resumes at 15859). Visible parameters: `sequence_length`,
// `use_sequence_length` (default false), `value` (default 0) and `axis`
// (default 0).
// NOTE(review): `data`, used below, is presumably declared on the missing
// opening line — confirm against the full header.
15859  Symbol sequence_length,
15860  bool use_sequence_length = false,
15861  mx_float value = 0,
15862  int axis = 0) {
// Configure a "SequenceMask" operator: forward `use_sequence_length`, `value`
// and `axis` as parameters, bind `data` and `sequence_length` as inputs, and
// create the symbol without an explicit name. `sequence_length` is bound
// unconditionally, regardless of the value of `use_sequence_length`.
15863  return Operator("SequenceMask")
15864  .SetParam("use_sequence_length", use_sequence_length)
15865  .SetParam("value", value)
15866  .SetParam("axis", axis)
15867  .SetInput("data", data)
15868  .SetInput("sequence_length", sequence_length)
15869  .CreateSymbol();
15870 }
15871 
// Tail of a function whose opening line is absent from this listing (the
// embedded numbering resumes at 15880); only the final parameter `rhs` is
// visible.
// NOTE(review): `lhs`, used below, is presumably declared on the missing
// opening line — confirm against the full header.
15880  Symbol rhs) {
// Configure a "choose_element_0index" operator with `lhs` and `rhs` as its two
// inputs (no parameters are set) and create the symbol without an explicit
// name.
15881  return Operator("choose_element_0index")
15882  .SetInput("lhs", lhs)
15883  .SetInput("rhs", rhs)
15884  .CreateSymbol();
15885 }
15886 
// Tail of a function whose opening line is absent from this listing (the
// embedded numbering resumes at 15896); the visible parameters are `mhs` and
// `rhs`.
// NOTE(review): `lhs`, used below, is presumably declared on the missing
// opening line — confirm against the full header.
15896  Symbol mhs,
15897  Symbol rhs) {
// Configure a "fill_element_0index" operator with `lhs`, `mhs` and `rhs` as its
// three inputs (no parameters are set) and create the symbol without an
// explicit name.
15898  return Operator("fill_element_0index")
15899  .SetInput("lhs", lhs)
15900  .SetInput("mhs", mhs)
15901  .SetInput("rhs", rhs)
15902  .CreateSymbol();
15903 }
15904 
15905 } //namespace cpp
15906 } //namespace mxnet
15907 #endif // MXNET_CPP_OP_H_
Symbol Convolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, ConvolutionCudnnTune cudnn_tune=ConvolutionCudnnTune::kNone, bool cudnn_off=false, ConvolutionLayout layout=ConvolutionLayout::kNone)
Definition: op.h:5253
Symbol Pooling(const std::string &symbol_name, Symbol data, Shape kernel=Shape(), PoolingPoolType pool_type=PoolingPoolType::kMax, bool global_pool=false, bool cudnn_off=false, PoolingPoolingConvention pooling_convention=PoolingPoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape(), dmlc::optional< int > p_value=dmlc::optional< int >(), dmlc::optional< bool > count_include_pad=dmlc::optional< bool >())
Definition: op.h:4636
Symbol mp_sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, Symbol weight32, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6461
Symbol fix(const std::string &symbol_name, Symbol data)
Definition: op.h:2147
Symbol Crop(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, Shape offset=Shape(0, 0), Shape h_w=Shape(0, 0), bool center_crop=false)
Definition: op.h:7364
Symbol broadcast_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1489
Symbol arcsin(const std::string &symbol_name, Symbol data)
Definition: op.h:4297
Symbol FullyConnected(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, int num_hidden, bool no_bias=false, bool flatten=true)
Definition: op.h:5672
Symbol arccosh(const std::string &symbol_name, Symbol data)
Definition: op.h:4520
Symbol arctan(const std::string &symbol_name, Symbol data)
Definition: op.h:4350
Symbol SwapAxis(const std::string &symbol_name, Symbol data, uint32_t dim1=0, uint32_t dim2=0)
Definition: op.h:5904
Symbol cast_storage(const std::string &symbol_name, Symbol data, Cast_storageStype stype)
Definition: op.h:4181
Symbol add_n(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:1030
Symbol log1p(const std::string &symbol_name, Symbol data)
Definition: op.h:2423
SoftmaxActivationMode
Definition: op.h:5578
Symbol SpatialTransformer(const std::string &symbol_name, Symbol data, Symbol loc, SpatialTransformerTransformType transform_type, SpatialTransformerSamplerType sampler_type, Shape target_shape=Shape(0, 0), dmlc::optional< bool > cudnn_off=dmlc::optional< bool >())
Definition: op.h:7491
Symbol argsort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true, ArgsortDtype dtype=ArgsortDtype::kFloat32)
Definition: op.h:3161
Symbol slice(const std::string &symbol_name, Symbol data, Shape begin, Shape end, Shape step=Shape())
Definition: op.h:481
Symbol diag(const std::string &symbol_name, Symbol data, int k=0, int axis1=0, int axis2=1)
Definition: op.h:4034
Symbol exp(const std::string &symbol_name, Symbol data)
Definition: op.h:2334
Symbol transpose(const std::string &symbol_name, Symbol data, Shape axes=Shape())
Definition: op.h:393
Symbol RNN(const std::string &symbol_name, Symbol data, Symbol parameters, Symbol state, Symbol state_cell, uint32_t state_size, uint32_t num_layers, RNNMode mode, bool bidirectional=false, mx_float p=0, bool state_outputs=false, dmlc::optional< int > projection_size=dmlc::optional< int >(), dmlc::optional< double > lstm_state_clip_min=dmlc::optional< double >(), dmlc::optional< double > lstm_state_clip_max=dmlc::optional< double >(), bool lstm_state_clip_nan=false)
Definition: op.h:7202
Symbol clip(const std::string &symbol_name, Symbol data, mx_float a_min, mx_float a_max)
Definition: op.h:645
Symbol elemwise_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3275
Symbol Embedding(const std::string &symbol_name, Symbol data, Symbol weight, int input_dim, int output_dim, EmbeddingDtype dtype=EmbeddingDtype::kFloat32, bool sparse_grad=false)
Definition: op.h:3359
Symbol ROIPooling(const std::string &symbol_name, Symbol data, Symbol rois, Shape pooled_size, mx_float spatial_scale)
Definition: op.h:7858
Symbol broadcast_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1524
Convolution_v1Layout
Definition: op.h:7260
Symbol argmin(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1110
Symbol dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false, DotForwardStype forward_stype=DotForwardStype::kNone)
Definition: op.h:1303
Symbol topk(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), int k=1, TopkRetTyp ret_typ=TopkRetTyp::kIndices, bool is_ascend=false, TopkDtype dtype=TopkDtype::kFloat32)
Definition: op.h:3045
Symbol SequenceReverse(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:7455
Symbol broadcast_lesser(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3837
Symbol fill_element_0index(const std::string &symbol_name, Symbol lhs, Symbol mhs, Symbol rhs)
Definition: op.h:8306
Symbol Convolution_v1(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, Convolution_v1CudnnTune cudnn_tune=Convolution_v1CudnnTune::kNone, bool cudnn_off=false, Convolution_v1Layout layout=Convolution_v1Layout::kNone)
Definition: op.h:7298
Symbol broadcast_not_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3744
TakeMode
Definition: op.h:3389
Symbol SequenceLast(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:7926
Symbol ftrl_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol z, Symbol n, mx_float lr, mx_float lamda1=0.01, mx_float beta=1, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:6823
Symbol reciprocal(const std::string &symbol_name, Symbol data)
Definition: op.h:1919
TopkRetTyp
Definition: op.h:2985
namespace of mxnet
Definition: base.h:118
Symbol reshape_like(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1770
Pooling_v1PoolingConvention
Definition: op.h:7038
Symbol broadcast_lesser_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3868
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
Symbol InstanceNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001)
Definition: op.h:6982
Symbol sign(const std::string &symbol_name, Symbol data)
Definition: op.h:1973
GridGeneratorTransformType
Definition: op.h:6998
Cast_storageStype
Definition: op.h:4130
Symbol log_softmax(const std::string &symbol_name, Symbol data, int axis=-1, dmlc::optional< double > temperature=dmlc::optional< double >())
Definition: op.h:4782
dynamic shape class that can hold shape of arbitrary dimension
Definition: shape.h:43
Symbol ones_like(const std::string &symbol_name, Symbol data)
Definition: op.h:1001
RNNMode
Definition: op.h:7124
PadMode
Definition: op.h:5692
Symbol smooth_l1(const std::string &symbol_name, Symbol data, mx_float scalar)
Definition: op.h:4119
Symbol where(const std::string &symbol_name, Symbol condition, Symbol x, Symbol y)
Definition: op.h:4082
Symbol space_to_depth(const std::string &symbol_name, Symbol data, int block_size)
Definition: op.h:945
Symbol expm1(const std::string &symbol_name, Symbol data)
Definition: op.h:2448
Symbol elemwise_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3200
PoolingPoolType
Definition: op.h:4553
Symbol relu(const std::string &symbol_name, Symbol data)
Definition: op.h:1583
Symbol reverse(const std::string &symbol_name, Symbol data, Shape axis)
Definition: op.h:778
Symbol rsqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2233
Symbol mp_sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol weight32, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6423
Symbol batch_dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false, Batch_dotForwardStype forward_stype=Batch_dotForwardStype::kNone)
Definition: op.h:1360
SpatialTransformerTransformType
Definition: op.h:7470
ActivationActType
Definition: op.h:4895
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2207
Symbol Softmax(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxNormalization normalization=SoftmaxNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:7682
Symbol rint(const std::string &symbol_name, Symbol data)
Definition: op.h:2031
Symbol IdentityAttachKLSparseReg(const std::string &symbol_name, Symbol data, mx_float sparseness_target=0.1, mx_float penalty=0.001, mx_float momentum=0.9)
Definition: op.h:6179
Symbol sinh(const std::string &symbol_name, Symbol data)
Definition: op.h:4428
Symbol scatter_nd(const std::string &symbol_name, Symbol data, Symbol indices, Shape shape)
Definition: op.h:3680
Symbol broadcast_greater_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3806
Symbol LRN(const std::string &symbol_name, Symbol data, uint32_t nsize, mx_float alpha=0.0001, mx_float beta=0.75, mx_float knorm=2)
Definition: op.h:5496
Symbol LayerNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, int axis=-1, mx_float eps=1e-05, bool output_mean_var=false)
Definition: op.h:5452
Symbol UpSampling(const std::string &symbol_name, const std::vector< Symbol > &data, int scale, UpSamplingSampleType sample_type, int num_args, int num_filter=0, UpSamplingMultiInputMode multi_input_mode=UpSamplingMultiInputMode::kConcat, uint64_t workspace=512)
Definition: op.h:5330
Symbol arcsinh(const std::string &symbol_name, Symbol data)
Definition: op.h:4500
Symbol MAERegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6112
Symbol SliceChannel(const std::string &symbol_name, Symbol data, int num_outputs, int axis=1, bool squeeze_axis=false)
Definition: op.h:6919
PoolingPoolingConvention
Definition: op.h:4562
Symbol broadcast_minimum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:180
ArgsortDtype
Definition: op.h:3121
Symbol broadcast_maximum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:147
Symbol Cast(const std::string &symbol_name, Symbol data, CastDtype dtype)
Definition: op.h:1864
Symbol erf(const std::string &symbol_name, Symbol data)
Definition: op.h:2284
DeconvolutionLayout
Definition: op.h:4804
Symbol trunc(const std::string &symbol_name, Symbol data)
Definition: op.h:2119
Pooling_v1PoolType
Definition: op.h:7030
Symbol round(const std::string &symbol_name, Symbol data)
Definition: op.h:2000
Symbol adam_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mean, Symbol var, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6601
Symbol Dropout(const std::string &symbol_name, Symbol data, mx_float p=0.5, DropoutMode mode=DropoutMode::kTraining, Shape axes=Shape())
Definition: op.h:5558
Symbol squeeze(const std::string &symbol_name, const std::vector< Symbol > &data, dmlc::optional< Shape > axis=dmlc::optional< Shape >())
Definition: op.h:843
TopkDtype
Definition: op.h:2994
Symbol broadcast_logical_or(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3930
Symbol khatri_rao(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:62
Symbol cos(const std::string &symbol_name, Symbol data)
Definition: op.h:4240
Symbol L2Normalization(const std::string &symbol_name, Symbol data, mx_float eps=1e-10, L2NormalizationMode mode=L2NormalizationMode::kInstance)
Definition: op.h:8009
Symbol max(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2756
Symbol Correlation(const std::string &symbol_name, Symbol data1, Symbol data2, uint32_t kernel_size=1, uint32_t max_displacement=1, uint32_t stride1=1, uint32_t stride2=1, uint32_t pad_size=0, bool is_multiply=true)
Definition: op.h:8166
CTCLossBlankLabel
Definition: op.h:5041
Symbol zeros_like(const std::string &symbol_name, Symbol data)
Definition: op.h:977
EmbeddingDtype
Definition: op.h:3286
Symbol broadcast_mod(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1555
Symbol cbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2263
Symbol broadcast_like(const std::string &symbol_name, Symbol lhs, Symbol rhs, dmlc::optional< Shape > lhs_axes=dmlc::optional< Shape >(), dmlc::optional< Shape > rhs_axes=dmlc::optional< Shape >())
Definition: op.h:2911
operator helper functions
Symbol broadcast_logical_and(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3899
Symbol logical_not(const std::string &symbol_name, Symbol data)
Definition: op.h:2502
Symbol tanh(const std::string &symbol_name, Symbol data)
Definition: op.h:4476
Symbol broadcast_to(const std::string &symbol_name, Symbol data, Shape shape=Shape())
Definition: op.h:2874
Symbol elemwise_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3228
Symbol BilinearSampler(const std::string &symbol_name, Symbol data, Symbol grid, dmlc::optional< bool > cudnn_off=dmlc::optional< bool >())
Definition: op.h:7791
DropoutMode
Definition: op.h:5513
Symbol norm(const std::string &symbol_name, Symbol data, int ord=2, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false)
Definition: op.h:2967
Symbol MakeLoss(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float valid_thresh=0, MakeLossNormalization normalization=MakeLossNormalization::kNull)
Definition: op.h:8069
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:2355
Symbol broadcast_logical_xor(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3961
Symbol sigmoid(const std::string &symbol_name, Symbol data)
Definition: op.h:1605
CastDtype
Definition: op.h:1835
DotForwardStype
Definition: op.h:1238
ConvolutionLayout
Definition: op.h:5147
Symbol LogisticRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6159
Symbol gamma(const std::string &symbol_name, Symbol data)
Definition: op.h:2466
Symbol sin(const std::string &symbol_name, Symbol data)
Definition: op.h:4216
UpSamplingMultiInputMode
Definition: op.h:5310
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
Symbol CTCLoss(const std::string &symbol_name, Symbol data, Symbol label, Symbol data_lengths, Symbol label_lengths, bool use_data_lengths=false, bool use_label_lengths=false, CTCLossBlankLabel blank_label=CTCLossBlankLabel::kFirst)
Definition: op.h:5112
Symbol elemwise_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3255
Symbol softmax(const std::string &symbol_name, Symbol data, int axis=-1, dmlc::optional< double > temperature=dmlc::optional< double >())
Definition: op.h:4704
Batch_dotForwardStype
Definition: op.h:1328
SpatialTransformerSamplerType
Definition: op.h:7476
Symbol Pad(const std::string &symbol_name, Symbol data, PadMode mode, Shape pad_width, double constant_value=0)
Definition: op.h:5794
Symbol square(const std::string &symbol_name, Symbol data)
Definition: op.h:2177
Symbol LeakyReLU(const std::string &symbol_name, Symbol data, Symbol gamma, LeakyReLUActType act_type=LeakyReLUActType::kLeaky, mx_float slope=0.25, mx_float lower_bound=0.125, mx_float upper_bound=0.334)
Definition: op.h:5852
One_hotDtype
Definition: op.h:3511
Symbol nansum(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2680
UpSamplingSampleType
Definition: op.h:5302
Symbol rmsprop_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, mx_float lr, mx_float gamma1=0.95, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:6683
Symbol make_loss(const std::string &symbol_name, Symbol data)
Definition: op.h:1729
Symbol SoftmaxActivation(const std::string &symbol_name, Symbol data, SoftmaxActivationMode mode=SoftmaxActivationMode::kInstance)
Definition: op.h:5616
Symbol broadcast_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3713
Symbol nanprod(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2719
Symbol Deconvolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), Shape adj=Shape(), Shape target_shape=Shape(), uint32_t num_group=1, uint64_t workspace=512, bool no_bias=true, DeconvolutionCudnnTune cudnn_tune=DeconvolutionCudnnTune::kNone, bool cudnn_off=false, DeconvolutionLayout layout=DeconvolutionLayout::kNone)
Definition: op.h:4842
Symbol broadcast_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1413
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:58
Symbol tan(const std::string &symbol_name, Symbol data)
Definition: op.h:4268
Convolution_v1CudnnTune
Definition: op.h:7250
Symbol repeat(const std::string &symbol_name, Symbol data, int repeats, dmlc::optional< int > axis=dmlc::optional< int >())
Definition: op.h:690
Symbol slice_axis(const std::string &symbol_name, Symbol data, int axis, int begin, dmlc::optional< int > end)
Definition: op.h:526
Symbol expand_dims(const std::string &symbol_name, Symbol data, int axis)
Definition: op.h:417
Symbol arctanh(const std::string &symbol_name, Symbol data)
Definition: op.h:4544
Symbol softmax_cross_entropy(const std::string &symbol_name, Symbol data, Symbol label)
Definition: op.h:6033
Symbol broadcast_axis(const std::string &symbol_name, Symbol data, Shape axis=Shape(), Shape size=Shape())
Definition: op.h:2835
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1946
Symbol cosh(const std::string &symbol_name, Symbol data)
Definition: op.h:4450
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:3107
Symbol gather_nd(const std::string &symbol_name, Symbol data, Symbol indices)
Definition: op.h:3620
Symbol slice_like(const std::string &symbol_name, Symbol data, Symbol shape_like, Shape axes=Shape())
Definition: op.h:600
Symbol Custom(const std::string &symbol_name, const std::vector< Symbol > &data, const std::string &op_type)
Definition: op.h:84
Symbol softmin(const std::string &symbol_name, Symbol data, int axis=-1, dmlc::optional< double > temperature=dmlc::optional< double >())
Definition: op.h:4748
Symbol Pooling_v1(const std::string &symbol_name, Symbol data, Shape kernel=Shape(), Pooling_v1PoolType pool_type=Pooling_v1PoolType::kMax, bool global_pool=false, Pooling_v1PoolingConvention pooling_convention=Pooling_v1PoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:7094
Symbol broadcast_hypot(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:219
Symbol BatchNorm_v1(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false)
Definition: op.h:5975
Symbol Activation(const std::string &symbol_name, Symbol data, ActivationActType act_type)
Definition: op.h:4922
float mx_float
manually define float
Definition: c_api.h:60
Symbol SVMOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float margin=1, mx_float regularization_coefficient=1, bool use_linear=false)
Definition: op.h:8102
Symbol radians(const std::string &symbol_name, Symbol data)
Definition: op.h:4402
Symbol Concat(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int dim=1)
Definition: op.h:5403
L2NormalizationMode
Definition: op.h:7941
Symbol SequenceMask(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, mx_float value=0, int axis=0)
Definition: op.h:8265
Symbol stack(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int axis=0)
Definition: op.h:810
Symbol floor(const std::string &symbol_name, Symbol data)
Definition: op.h:2089
Symbol broadcast_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1454
Symbol take(const std::string &symbol_name, Symbol a, Symbol indices, int axis=0, TakeMode mode=TakeMode::kClip)
Definition: op.h:3454
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:2060
Symbol gammaln(const std::string &symbol_name, Symbol data)
Definition: op.h:2484
Symbol tile(const std::string &symbol_name, Symbol data, Shape reps)
Definition: op.h:746
Symbol min(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2793
Symbol signum_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float wd_lh=0)
Definition: op.h:6271
Symbol depth_to_space(const std::string &symbol_name, Symbol data, int block_size)
Definition: op.h:893
Symbol rmspropalex_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, Symbol g, Symbol delta, mx_float lr, mx_float gamma1=0.95, mx_float gamma2=0.9, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:6753
Symbol ftml_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol d, Symbol v, Symbol z, mx_float lr, int t, mx_float beta1=0.6, mx_float beta2=0.999, double epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_grad=-1)
Definition: op.h:6521
SoftmaxNormalization
Definition: op.h:7649
DeconvolutionCudnnTune
Definition: op.h:4795
ConvolutionCudnnTune
Definition: op.h:5137
Symbol prod(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2641
Symbol pick(const std::string &symbol_name, Symbol data, Symbol index, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool keepdims=false, PickMode mode=PickMode::kClip)
Definition: op.h:1215
definition of shape
Symbol broadcast_greater(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3775
Symbol BatchNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, Symbol moving_mean, Symbol moving_var, double eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false, int axis=1, bool cudnn_off=false)
Definition: op.h:5006
Symbol rcbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2308
PickMode
Definition: op.h:1154
Symbol signsgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:6220
Symbol broadcast_power(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:114
SoftmaxOutputNormalization
Definition: op.h:7516
Symbol SoftmaxOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxOutputNormalization normalization=SoftmaxOutputNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:7617
Symbol Flatten(const std::string &symbol_name, Symbol data)
Definition: op.h:350
Symbol BlockGrad(const std::string &symbol_name, Symbol data)
Definition: op.h:1693
LeakyReLUActType
Definition: op.h:5814
Symbol sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6385
Symbol arccos(const std::string &symbol_name, Symbol data)
Definition: op.h:4322
Symbol hard_sigmoid(const std::string &symbol_name, Symbol data, mx_float alpha=0.2, mx_float beta=0.5)
Definition: op.h:1627
Symbol argmax_channel(const std::string &symbol_name, Symbol data)
Definition: op.h:1143
Symbol batch_take(const std::string &symbol_name, Symbol a, Symbol indices)
Definition: op.h:3500
Symbol mean(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2604
Symbol LinearRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6071
Symbol softsign(const std::string &symbol_name, Symbol data)
Definition: op.h:1653
Symbol sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6322
Symbol choose_element_0index(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:8288
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Definition: op.h:302
Symbol degrees(const std::string &symbol_name, Symbol data)
Definition: op.h:4376
Symbol shape_array(const std::string &symbol_name, Symbol data, dmlc::optional< int > lhs_begin=dmlc::optional< int >(), dmlc::optional< int > lhs_end=dmlc::optional< int >(), dmlc::optional< int > rhs_begin=dmlc::optional< int >(), dmlc::optional< int > rhs_end=dmlc::optional< int >())
Definition: op.h:1797
Symbol one_hot(const std::string &symbol_name, Symbol indices, int depth, double on_value=1, double off_value=0, One_hotDtype dtype=One_hotDtype::kFloat32)
Definition: op.h:3565
Symbol negative(const std::string &symbol_name, Symbol data)
Definition: op.h:1896
Symbol sum(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2567
Symbol GridGenerator(const std::string &symbol_name, Symbol data, GridGeneratorTransformType transform_type, Shape target_shape=Shape(0, 0))
Definition: op.h:7013
Symbol size_array(const std::string &symbol_name, Symbol data)
Definition: op.h:1826
Symbol argmax(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1068
Operator interface.
Definition: operator.h:43
Symbol interface.
Definition: symbol.h:72
MakeLossNormalization
Definition: op.h:8029
Symbol log10(const std::string &symbol_name, Symbol data)
Definition: op.h:2376
Symbol log2(const std::string &symbol_name, Symbol data)
Definition: op.h:2397