mxnet
op.h
Go to the documentation of this file.
1 
8 #ifndef MXNET_CPP_OP_H_
9 #define MXNET_CPP_OP_H_
10 
11 #include <string>
12 #include <vector>
13 #include "mxnet-cpp/base.h"
14 #include "mxnet-cpp/shape.h"
15 #include "mxnet-cpp/op_util.h"
16 #include "mxnet-cpp/operator.h"
17 #include "dmlc/optional.h"
18 
19 namespace mxnet {
20 namespace cpp {
21 
62 inline Symbol khatri_rao(const std::string& symbol_name,
63  const std::vector<Symbol>& args) {
64  return Operator("khatri_rao")
65 (args)
66  .CreateSymbol(symbol_name);
67 }
68 
84 inline Symbol Custom(const std::string& symbol_name,
85  const std::vector<Symbol>& data,
86  const std::string& op_type) {
87  return Operator("Custom")
88 (data)
89  .CreateSymbol(symbol_name);
90 }
91 
114 inline Symbol broadcast_power(const std::string& symbol_name,
115  Symbol lhs,
116  Symbol rhs) {
117  return Operator("broadcast_power")
118  .SetInput("lhs", lhs)
119  .SetInput("rhs", rhs)
120  .CreateSymbol(symbol_name);
121 }
122 
147 inline Symbol broadcast_maximum(const std::string& symbol_name,
148  Symbol lhs,
149  Symbol rhs) {
150  return Operator("broadcast_maximum")
151  .SetInput("lhs", lhs)
152  .SetInput("rhs", rhs)
153  .CreateSymbol(symbol_name);
154 }
155 
180 inline Symbol broadcast_minimum(const std::string& symbol_name,
181  Symbol lhs,
182  Symbol rhs) {
183  return Operator("broadcast_minimum")
184  .SetInput("lhs", lhs)
185  .SetInput("rhs", rhs)
186  .CreateSymbol(symbol_name);
187 }
188 
219 inline Symbol broadcast_hypot(const std::string& symbol_name,
220  Symbol lhs,
221  Symbol rhs) {
222  return Operator("broadcast_hypot")
223  .SetInput("lhs", lhs)
224  .SetInput("rhs", rhs)
225  .CreateSymbol(symbol_name);
226 }
227 
302 inline Symbol Reshape(const std::string& symbol_name,
303  Symbol data,
304  Shape shape = Shape(),
305  bool reverse = false,
306  Shape target_shape = Shape(),
307  bool keep_highest = false) {
308  return Operator("Reshape")
309  .SetParam("shape", shape)
310  .SetParam("reverse", reverse)
311  .SetParam("target_shape", target_shape)
312  .SetParam("keep_highest", keep_highest)
313  .SetInput("data", data)
314  .CreateSymbol(symbol_name);
315 }
316 
350 inline Symbol Flatten(const std::string& symbol_name,
351  Symbol data) {
352  return Operator("Flatten")
353  .SetInput("data", data)
354  .CreateSymbol(symbol_name);
355 }
356 
393 inline Symbol transpose(const std::string& symbol_name,
394  Symbol data,
395  Shape axes = Shape()) {
396  return Operator("transpose")
397  .SetParam("axes", axes)
398  .SetInput("data", data)
399  .CreateSymbol(symbol_name);
400 }
401 
417 inline Symbol expand_dims(const std::string& symbol_name,
418  Symbol data,
419  int axis) {
420  return Operator("expand_dims")
421  .SetParam("axis", axis)
422  .SetInput("data", data)
423  .CreateSymbol(symbol_name);
424 }
425 
481 inline Symbol slice(const std::string& symbol_name,
482  Symbol data,
483  Shape begin,
484  Shape end,
485  Shape step = Shape()) {
486  return Operator("slice")
487  .SetParam("begin", begin)
488  .SetParam("end", end)
489  .SetParam("step", step)
490  .SetInput("data", data)
491  .CreateSymbol(symbol_name);
492 }
493 
526 inline Symbol slice_axis(const std::string& symbol_name,
527  Symbol data,
528  int axis,
529  int begin,
530  dmlc::optional<int> end) {
531  return Operator("slice_axis")
532  .SetParam("axis", axis)
533  .SetParam("begin", begin)
534  .SetParam("end", end)
535  .SetInput("data", data)
536  .CreateSymbol(symbol_name);
537 }
538 
600 inline Symbol slice_like(const std::string& symbol_name,
601  Symbol data,
602  Symbol shape_like,
603  Shape axes = Shape()) {
604  return Operator("slice_like")
605  .SetParam("axes", axes)
606  .SetInput("data", data)
607  .SetInput("shape_like", shape_like)
608  .CreateSymbol(symbol_name);
609 }
610 
645 inline Symbol clip(const std::string& symbol_name,
646  Symbol data,
647  mx_float a_min,
648  mx_float a_max) {
649  return Operator("clip")
650  .SetParam("a_min", a_min)
651  .SetParam("a_max", a_max)
652  .SetInput("data", data)
653  .CreateSymbol(symbol_name);
654 }
655 
690 inline Symbol repeat(const std::string& symbol_name,
691  Symbol data,
692  int repeats,
693  dmlc::optional<int> axis = dmlc::optional<int>()) {
694  return Operator("repeat")
695  .SetParam("repeats", repeats)
696  .SetParam("axis", axis)
697  .SetInput("data", data)
698  .CreateSymbol(symbol_name);
699 }
700 
746 inline Symbol tile(const std::string& symbol_name,
747  Symbol data,
748  Shape reps) {
749  return Operator("tile")
750  .SetParam("reps", reps)
751  .SetInput("data", data)
752  .CreateSymbol(symbol_name);
753 }
754 
778 inline Symbol reverse(const std::string& symbol_name,
779  Symbol data,
780  Shape axis) {
781  return Operator("reverse")
782  .SetParam("axis", axis)
783  .SetInput("data", data)
784  .CreateSymbol(symbol_name);
785 }
786 
810 inline Symbol stack(const std::string& symbol_name,
811  const std::vector<Symbol>& data,
812  int num_args,
813  int axis = 0) {
814  return Operator("stack")
815  .SetParam("num_args", num_args)
816  .SetParam("axis", axis)
817 (data)
818  .CreateSymbol(symbol_name);
819 }
820 
843 inline Symbol squeeze(const std::string& symbol_name,
844  const std::vector<Symbol>& data,
845  dmlc::optional<Shape> axis = dmlc::optional<Shape>()) {
846  return Operator("squeeze")
847  .SetParam("axis", axis)
848 (data)
849  .CreateSymbol(symbol_name);
850 }
851 
893 inline Symbol depth_to_space(const std::string& symbol_name,
894  Symbol data,
895  int block_size) {
896  return Operator("depth_to_space")
897  .SetParam("block_size", block_size)
898  .SetInput("data", data)
899  .CreateSymbol(symbol_name);
900 }
901 
945 inline Symbol space_to_depth(const std::string& symbol_name,
946  Symbol data,
947  int block_size) {
948  return Operator("space_to_depth")
949  .SetParam("block_size", block_size)
950  .SetInput("data", data)
951  .CreateSymbol(symbol_name);
952 }
953 
977 inline Symbol zeros_like(const std::string& symbol_name,
978  Symbol data) {
979  return Operator("zeros_like")
980  .SetInput("data", data)
981  .CreateSymbol(symbol_name);
982 }
983 
1001 inline Symbol ones_like(const std::string& symbol_name,
1002  Symbol data) {
1003  return Operator("ones_like")
1004  .SetInput("data", data)
1005  .CreateSymbol(symbol_name);
1006 }
1007 
1040 inline Symbol broadcast_add(const std::string& symbol_name,
1041  Symbol lhs,
1042  Symbol rhs) {
1043  return Operator("broadcast_add")
1044  .SetInput("lhs", lhs)
1045  .SetInput("rhs", rhs)
1046  .CreateSymbol(symbol_name);
1047 }
1048 
1081 inline Symbol broadcast_sub(const std::string& symbol_name,
1082  Symbol lhs,
1083  Symbol rhs) {
1084  return Operator("broadcast_sub")
1085  .SetInput("lhs", lhs)
1086  .SetInput("rhs", rhs)
1087  .CreateSymbol(symbol_name);
1088 }
1089 
1116 inline Symbol broadcast_mul(const std::string& symbol_name,
1117  Symbol lhs,
1118  Symbol rhs) {
1119  return Operator("broadcast_mul")
1120  .SetInput("lhs", lhs)
1121  .SetInput("rhs", rhs)
1122  .CreateSymbol(symbol_name);
1123 }
1124 
1151 inline Symbol broadcast_div(const std::string& symbol_name,
1152  Symbol lhs,
1153  Symbol rhs) {
1154  return Operator("broadcast_div")
1155  .SetInput("lhs", lhs)
1156  .SetInput("rhs", rhs)
1157  .CreateSymbol(symbol_name);
1158 }
1159 
1182 inline Symbol broadcast_mod(const std::string& symbol_name,
1183  Symbol lhs,
1184  Symbol rhs) {
1185  return Operator("broadcast_mod")
1186  .SetInput("lhs", lhs)
1187  .SetInput("rhs", rhs)
1188  .CreateSymbol(symbol_name);
1189 }
1190 
1213 inline Symbol add_n(const std::string& symbol_name,
1214  const std::vector<Symbol>& args) {
1215  return Operator("add_n")
1216 (args)
1217  .CreateSymbol(symbol_name);
1218 }
1219 
1251 inline Symbol argmax(const std::string& symbol_name,
1252  Symbol data,
1253  dmlc::optional<int> axis = dmlc::optional<int>(),
1254  bool keepdims = false) {
1255  return Operator("argmax")
1256  .SetParam("axis", axis)
1257  .SetParam("keepdims", keepdims)
1258  .SetInput("data", data)
1259  .CreateSymbol(symbol_name);
1260 }
1261 
1293 inline Symbol argmin(const std::string& symbol_name,
1294  Symbol data,
1295  dmlc::optional<int> axis = dmlc::optional<int>(),
1296  bool keepdims = false) {
1297  return Operator("argmin")
1298  .SetParam("axis", axis)
1299  .SetParam("keepdims", keepdims)
1300  .SetInput("data", data)
1301  .CreateSymbol(symbol_name);
1302 }
1303 
1326 inline Symbol argmax_channel(const std::string& symbol_name,
1327  Symbol data) {
1328  return Operator("argmax_channel")
1329  .SetInput("data", data)
1330  .CreateSymbol(symbol_name);
1331 }
1332 
/*! \brief Value for pick's "mode" parameter; the integer values index pick's string table ("clip", "wrap"). */
enum class PickMode { kClip = 0, kWrap = 1 };
1341 
1398 inline Symbol pick(const std::string& symbol_name,
1399  Symbol data,
1400  Symbol index,
1401  dmlc::optional<int> axis = dmlc::optional<int>(-1),
1402  bool keepdims = false,
1403  PickMode mode = PickMode::kClip) {
1404  static const char *PickModeValues[] = {
1405  "clip",
1406  "wrap"
1407  };
1408  return Operator("pick")
1409  .SetParam("axis", axis)
1410  .SetParam("keepdims", keepdims)
1411  .SetParam("mode", PickModeValues[int(mode)])
1412  .SetInput("data", data)
1413  .SetInput("index", index)
1414  .CreateSymbol(symbol_name);
1415 }
1416 
/*! \brief Value for dot's "forward_stype" parameter; the integer values index dot's string table ("None", "csr", "default", "row_sparse"). */
enum class DotForwardStype { kNone = 0, kCsr = 1, kDefault = 2, kRow_sparse = 3 };
1427 
1486 inline Symbol dot(const std::string& symbol_name,
1487  Symbol lhs,
1488  Symbol rhs,
1489  bool transpose_a = false,
1490  bool transpose_b = false,
1491  DotForwardStype forward_stype = DotForwardStype::kNone) {
1492  static const char *DotForwardStypeValues[] = {
1493  "None",
1494  "csr",
1495  "default",
1496  "row_sparse"
1497  };
1498  return Operator("dot")
1499  .SetParam("transpose_a", transpose_a)
1500  .SetParam("transpose_b", transpose_b)
1501  .SetParam("forward_stype", DotForwardStypeValues[int(forward_stype)])
1502  .SetInput("lhs", lhs)
1503  .SetInput("rhs", rhs)
1504  .CreateSymbol(symbol_name);
1505 }
1506 
1512  kNone = 0,
1513  kCsr = 1,
1514  kDefault = 2,
1515  kRow_sparse = 3
1516 };
1517 
1543 inline Symbol batch_dot(const std::string& symbol_name,
1544  Symbol lhs,
1545  Symbol rhs,
1546  bool transpose_a = false,
1547  bool transpose_b = false,
1549  static const char *Batch_dotForwardStypeValues[] = {
1550  "None",
1551  "csr",
1552  "default",
1553  "row_sparse"
1554  };
1555  return Operator("batch_dot")
1556  .SetParam("transpose_a", transpose_a)
1557  .SetParam("transpose_b", transpose_b)
1558  .SetParam("forward_stype", Batch_dotForwardStypeValues[int(forward_stype)])
1559  .SetInput("lhs", lhs)
1560  .SetInput("rhs", rhs)
1561  .CreateSymbol(symbol_name);
1562 }
1563 
1583 inline Symbol relu(const std::string& symbol_name,
1584  Symbol data) {
1585  return Operator("relu")
1586  .SetInput("data", data)
1587  .CreateSymbol(symbol_name);
1588 }
1589 
1605 inline Symbol sigmoid(const std::string& symbol_name,
1606  Symbol data) {
1607  return Operator("sigmoid")
1608  .SetInput("data", data)
1609  .CreateSymbol(symbol_name);
1610 }
1611 
1627 inline Symbol hard_sigmoid(const std::string& symbol_name,
1628  Symbol data,
1629  mx_float alpha = 0.2,
1630  mx_float beta = 0.5) {
1631  return Operator("hard_sigmoid")
1632  .SetParam("alpha", alpha)
1633  .SetParam("beta", beta)
1634  .SetInput("data", data)
1635  .CreateSymbol(symbol_name);
1636 }
1637 
1653 inline Symbol softsign(const std::string& symbol_name,
1654  Symbol data) {
1655  return Operator("softsign")
1656  .SetInput("data", data)
1657  .CreateSymbol(symbol_name);
1658 }
1659 
1693 inline Symbol BlockGrad(const std::string& symbol_name,
1694  Symbol data) {
1695  return Operator("BlockGrad")
1696  .SetInput("data", data)
1697  .CreateSymbol(symbol_name);
1698 }
1699 
1729 inline Symbol make_loss(const std::string& symbol_name,
1730  Symbol data) {
1731  return Operator("make_loss")
1732  .SetInput("data", data)
1733  .CreateSymbol(symbol_name);
1734 }
1735 
1770 inline Symbol reshape_like(const std::string& symbol_name,
1771  Symbol lhs,
1772  Symbol rhs) {
1773  return Operator("reshape_like")
1774  .SetInput("lhs", lhs)
1775  .SetInput("rhs", rhs)
1776  .CreateSymbol(symbol_name);
1777 }
1778 
1797 inline Symbol shape_array(const std::string& symbol_name,
1798  Symbol data,
1799  dmlc::optional<int> lhs_begin = dmlc::optional<int>(),
1800  dmlc::optional<int> lhs_end = dmlc::optional<int>(),
1801  dmlc::optional<int> rhs_begin = dmlc::optional<int>(),
1802  dmlc::optional<int> rhs_end = dmlc::optional<int>()) {
1803  return Operator("shape_array")
1804  .SetParam("lhs_begin", lhs_begin)
1805  .SetParam("lhs_end", lhs_end)
1806  .SetParam("rhs_begin", rhs_begin)
1807  .SetParam("rhs_end", rhs_end)
1808  .SetInput("data", data)
1809  .CreateSymbol(symbol_name);
1810 }
1811 
1826 inline Symbol size_array(const std::string& symbol_name,
1827  Symbol data) {
1828  return Operator("size_array")
1829  .SetInput("data", data)
1830  .CreateSymbol(symbol_name);
1831 }
1832 
/*! \brief Value for Cast's "dtype" parameter; the integer values index Cast's dtype-name table. */
enum class CastDtype {
  kFloat16 = 0, kFloat32 = 1, kFloat64 = 2,
  kInt32 = 3, kInt64 = 4, kInt8 = 5, kUint8 = 6
};
1844 
1864 inline Symbol Cast(const std::string& symbol_name,
1865  Symbol data,
1866  CastDtype dtype) {
1867  static const char *CastDtypeValues[] = {
1868  "float16",
1869  "float32",
1870  "float64",
1871  "int32",
1872  "int64",
1873  "int8",
1874  "uint8"
1875  };
1876  return Operator("Cast")
1877  .SetParam("dtype", CastDtypeValues[int(dtype)])
1878  .SetInput("data", data)
1879  .CreateSymbol(symbol_name);
1880 }
1881 
1896 inline Symbol negative(const std::string& symbol_name,
1897  Symbol data) {
1898  return Operator("negative")
1899  .SetInput("data", data)
1900  .CreateSymbol(symbol_name);
1901 }
1902 
1919 inline Symbol reciprocal(const std::string& symbol_name,
1920  Symbol data) {
1921  return Operator("reciprocal")
1922  .SetInput("data", data)
1923  .CreateSymbol(symbol_name);
1924 }
1925 
1946 inline Symbol abs(const std::string& symbol_name,
1947  Symbol data) {
1948  return Operator("abs")
1949  .SetInput("data", data)
1950  .CreateSymbol(symbol_name);
1951 }
1952 
1973 inline Symbol sign(const std::string& symbol_name,
1974  Symbol data) {
1975  return Operator("sign")
1976  .SetInput("data", data)
1977  .CreateSymbol(symbol_name);
1978 }
1979 
2000 inline Symbol round(const std::string& symbol_name,
2001  Symbol data) {
2002  return Operator("round")
2003  .SetInput("data", data)
2004  .CreateSymbol(symbol_name);
2005 }
2006 
2031 inline Symbol rint(const std::string& symbol_name,
2032  Symbol data) {
2033  return Operator("rint")
2034  .SetInput("data", data)
2035  .CreateSymbol(symbol_name);
2036 }
2037 
2060 inline Symbol ceil(const std::string& symbol_name,
2061  Symbol data) {
2062  return Operator("ceil")
2063  .SetInput("data", data)
2064  .CreateSymbol(symbol_name);
2065 }
2066 
2089 inline Symbol floor(const std::string& symbol_name,
2090  Symbol data) {
2091  return Operator("floor")
2092  .SetInput("data", data)
2093  .CreateSymbol(symbol_name);
2094 }
2095 
2119 inline Symbol trunc(const std::string& symbol_name,
2120  Symbol data) {
2121  return Operator("trunc")
2122  .SetInput("data", data)
2123  .CreateSymbol(symbol_name);
2124 }
2125 
2147 inline Symbol fix(const std::string& symbol_name,
2148  Symbol data) {
2149  return Operator("fix")
2150  .SetInput("data", data)
2151  .CreateSymbol(symbol_name);
2152 }
2153 
2177 inline Symbol square(const std::string& symbol_name,
2178  Symbol data) {
2179  return Operator("square")
2180  .SetInput("data", data)
2181  .CreateSymbol(symbol_name);
2182 }
2183 
2207 inline Symbol sqrt(const std::string& symbol_name,
2208  Symbol data) {
2209  return Operator("sqrt")
2210  .SetInput("data", data)
2211  .CreateSymbol(symbol_name);
2212 }
2213 
2233 inline Symbol rsqrt(const std::string& symbol_name,
2234  Symbol data) {
2235  return Operator("rsqrt")
2236  .SetInput("data", data)
2237  .CreateSymbol(symbol_name);
2238 }
2239 
2263 inline Symbol cbrt(const std::string& symbol_name,
2264  Symbol data) {
2265  return Operator("cbrt")
2266  .SetInput("data", data)
2267  .CreateSymbol(symbol_name);
2268 }
2269 
2287 inline Symbol rcbrt(const std::string& symbol_name,
2288  Symbol data) {
2289  return Operator("rcbrt")
2290  .SetInput("data", data)
2291  .CreateSymbol(symbol_name);
2292 }
2293 
2313 inline Symbol exp(const std::string& symbol_name,
2314  Symbol data) {
2315  return Operator("exp")
2316  .SetInput("data", data)
2317  .CreateSymbol(symbol_name);
2318 }
2319 
2334 inline Symbol log(const std::string& symbol_name,
2335  Symbol data) {
2336  return Operator("log")
2337  .SetInput("data", data)
2338  .CreateSymbol(symbol_name);
2339 }
2340 
2355 inline Symbol log10(const std::string& symbol_name,
2356  Symbol data) {
2357  return Operator("log10")
2358  .SetInput("data", data)
2359  .CreateSymbol(symbol_name);
2360 }
2361 
2376 inline Symbol log2(const std::string& symbol_name,
2377  Symbol data) {
2378  return Operator("log2")
2379  .SetInput("data", data)
2380  .CreateSymbol(symbol_name);
2381 }
2382 
2402 inline Symbol log1p(const std::string& symbol_name,
2403  Symbol data) {
2404  return Operator("log1p")
2405  .SetInput("data", data)
2406  .CreateSymbol(symbol_name);
2407 }
2408 
2427 inline Symbol expm1(const std::string& symbol_name,
2428  Symbol data) {
2429  return Operator("expm1")
2430  .SetInput("data", data)
2431  .CreateSymbol(symbol_name);
2432 }
2433 
2445 inline Symbol gamma(const std::string& symbol_name,
2446  Symbol data) {
2447  return Operator("gamma")
2448  .SetInput("data", data)
2449  .CreateSymbol(symbol_name);
2450 }
2451 
2463 inline Symbol gammaln(const std::string& symbol_name,
2464  Symbol data) {
2465  return Operator("gammaln")
2466  .SetInput("data", data)
2467  .CreateSymbol(symbol_name);
2468 }
2469 
2481 inline Symbol logical_not(const std::string& symbol_name,
2482  Symbol data) {
2483  return Operator("logical_not")
2484  .SetInput("data", data)
2485  .CreateSymbol(symbol_name);
2486 }
2487 
2546 inline Symbol sum(const std::string& symbol_name,
2547  Symbol data,
2548  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2549  bool keepdims = false,
2550  bool exclude = false) {
2551  return Operator("sum")
2552  .SetParam("axis", axis)
2553  .SetParam("keepdims", keepdims)
2554  .SetParam("exclude", exclude)
2555  .SetInput("data", data)
2556  .CreateSymbol(symbol_name);
2557 }
2558 
2583 inline Symbol mean(const std::string& symbol_name,
2584  Symbol data,
2585  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2586  bool keepdims = false,
2587  bool exclude = false) {
2588  return Operator("mean")
2589  .SetParam("axis", axis)
2590  .SetParam("keepdims", keepdims)
2591  .SetParam("exclude", exclude)
2592  .SetInput("data", data)
2593  .CreateSymbol(symbol_name);
2594 }
2595 
2620 inline Symbol prod(const std::string& symbol_name,
2621  Symbol data,
2622  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2623  bool keepdims = false,
2624  bool exclude = false) {
2625  return Operator("prod")
2626  .SetParam("axis", axis)
2627  .SetParam("keepdims", keepdims)
2628  .SetParam("exclude", exclude)
2629  .SetInput("data", data)
2630  .CreateSymbol(symbol_name);
2631 }
2632 
2659 inline Symbol nansum(const std::string& symbol_name,
2660  Symbol data,
2661  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2662  bool keepdims = false,
2663  bool exclude = false) {
2664  return Operator("nansum")
2665  .SetParam("axis", axis)
2666  .SetParam("keepdims", keepdims)
2667  .SetParam("exclude", exclude)
2668  .SetInput("data", data)
2669  .CreateSymbol(symbol_name);
2670 }
2671 
2698 inline Symbol nanprod(const std::string& symbol_name,
2699  Symbol data,
2700  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2701  bool keepdims = false,
2702  bool exclude = false) {
2703  return Operator("nanprod")
2704  .SetParam("axis", axis)
2705  .SetParam("keepdims", keepdims)
2706  .SetParam("exclude", exclude)
2707  .SetInput("data", data)
2708  .CreateSymbol(symbol_name);
2709 }
2710 
2735 inline Symbol max(const std::string& symbol_name,
2736  Symbol data,
2737  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2738  bool keepdims = false,
2739  bool exclude = false) {
2740  return Operator("max")
2741  .SetParam("axis", axis)
2742  .SetParam("keepdims", keepdims)
2743  .SetParam("exclude", exclude)
2744  .SetInput("data", data)
2745  .CreateSymbol(symbol_name);
2746 }
2747 
2772 inline Symbol min(const std::string& symbol_name,
2773  Symbol data,
2774  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2775  bool keepdims = false,
2776  bool exclude = false) {
2777  return Operator("min")
2778  .SetParam("axis", axis)
2779  .SetParam("keepdims", keepdims)
2780  .SetParam("exclude", exclude)
2781  .SetInput("data", data)
2782  .CreateSymbol(symbol_name);
2783 }
2784 
2814 inline Symbol broadcast_axis(const std::string& symbol_name,
2815  Symbol data,
2816  Shape axis = Shape(),
2817  Shape size = Shape()) {
2818  return Operator("broadcast_axis")
2819  .SetParam("axis", axis)
2820  .SetParam("size", size)
2821  .SetInput("data", data)
2822  .CreateSymbol(symbol_name);
2823 }
2824 
2853 inline Symbol broadcast_to(const std::string& symbol_name,
2854  Symbol data,
2855  Shape shape = Shape()) {
2856  return Operator("broadcast_to")
2857  .SetParam("shape", shape)
2858  .SetInput("data", data)
2859  .CreateSymbol(symbol_name);
2860 }
2861 
2886 inline Symbol broadcast_like(const std::string& symbol_name,
2887  Symbol lhs,
2888  Symbol rhs) {
2889  return Operator("broadcast_like")
2890  .SetInput("lhs", lhs)
2891  .SetInput("rhs", rhs)
2892  .CreateSymbol(symbol_name);
2893 }
2894 
2938 inline Symbol norm(const std::string& symbol_name,
2939  Symbol data,
2940  int ord = 2,
2941  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
2942  bool keepdims = false) {
2943  return Operator("norm")
2944  .SetParam("ord", ord)
2945  .SetParam("axis", axis)
2946  .SetParam("keepdims", keepdims)
2947  .SetInput("data", data)
2948  .CreateSymbol(symbol_name);
2949 }
2950 
/*! \brief Value for topk's "ret_typ" parameter; the integer values index topk's table ("both", "indices", "mask", "value"). */
enum class TopkRetTyp { kBoth = 0, kIndices = 1, kMask = 2, kValue = 3 };
2962 
/*! \brief Value for topk's "dtype" parameter; the integer values index topk's dtype-name table. */
enum class TopkDtype { kFloat16 = 0, kFloat32 = 1, kFloat64 = 2, kInt32 = 3, kUint8 = 4 };
2972 
3016 inline Symbol topk(const std::string& symbol_name,
3017  Symbol data,
3018  dmlc::optional<int> axis = dmlc::optional<int>(-1),
3019  int k = 1,
3020  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
3021  bool is_ascend = false,
3022  TopkDtype dtype = TopkDtype::kFloat32) {
3023  static const char *TopkRetTypValues[] = {
3024  "both",
3025  "indices",
3026  "mask",
3027  "value"
3028  };
3029  static const char *TopkDtypeValues[] = {
3030  "float16",
3031  "float32",
3032  "float64",
3033  "int32",
3034  "uint8"
3035  };
3036  return Operator("topk")
3037  .SetParam("axis", axis)
3038  .SetParam("k", k)
3039  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
3040  .SetParam("is_ascend", is_ascend)
3041  .SetParam("dtype", TopkDtypeValues[int(dtype)])
3042  .SetInput("data", data)
3043  .CreateSymbol(symbol_name);
3044 }
3045 
3078 inline Symbol sort(const std::string& symbol_name,
3079  Symbol data,
3080  dmlc::optional<int> axis = dmlc::optional<int>(-1),
3081  bool is_ascend = true) {
3082  return Operator("sort")
3083  .SetParam("axis", axis)
3084  .SetParam("is_ascend", is_ascend)
3085  .SetInput("data", data)
3086  .CreateSymbol(symbol_name);
3087 }
3088 
/*! \brief Value for argsort's "dtype" parameter; the integer values index argsort's dtype-name table. */
enum class ArgsortDtype { kFloat16 = 0, kFloat32 = 1, kFloat64 = 2, kInt32 = 3, kUint8 = 4 };
3099 
3132 inline Symbol argsort(const std::string& symbol_name,
3133  Symbol data,
3134  dmlc::optional<int> axis = dmlc::optional<int>(-1),
3135  bool is_ascend = true,
3137  static const char *ArgsortDtypeValues[] = {
3138  "float16",
3139  "float32",
3140  "float64",
3141  "int32",
3142  "uint8"
3143  };
3144  return Operator("argsort")
3145  .SetParam("axis", axis)
3146  .SetParam("is_ascend", is_ascend)
3147  .SetParam("dtype", ArgsortDtypeValues[int(dtype)])
3148  .SetInput("data", data)
3149  .CreateSymbol(symbol_name);
3150 }
3151 
3171 inline Symbol elemwise_add(const std::string& symbol_name,
3172  Symbol lhs,
3173  Symbol rhs) {
3174  return Operator("elemwise_add")
3175  .SetInput("lhs", lhs)
3176  .SetInput("rhs", rhs)
3177  .CreateSymbol(symbol_name);
3178 }
3179 
3199 inline Symbol elemwise_sub(const std::string& symbol_name,
3200  Symbol lhs,
3201  Symbol rhs) {
3202  return Operator("elemwise_sub")
3203  .SetInput("lhs", lhs)
3204  .SetInput("rhs", rhs)
3205  .CreateSymbol(symbol_name);
3206 }
3207 
3226 inline Symbol elemwise_mul(const std::string& symbol_name,
3227  Symbol lhs,
3228  Symbol rhs) {
3229  return Operator("elemwise_mul")
3230  .SetInput("lhs", lhs)
3231  .SetInput("rhs", rhs)
3232  .CreateSymbol(symbol_name);
3233 }
3234 
3246 inline Symbol elemwise_div(const std::string& symbol_name,
3247  Symbol lhs,
3248  Symbol rhs) {
3249  return Operator("elemwise_div")
3250  .SetInput("lhs", lhs)
3251  .SetInput("rhs", rhs)
3252  .CreateSymbol(symbol_name);
3253 }
3254 
/*! \brief Value for Embedding's "dtype" parameter; the integer values index Embedding's dtype-name table. */
enum class EmbeddingDtype {
  kFloat16 = 0, kFloat32 = 1, kFloat64 = 2,
  kInt32 = 3, kInt64 = 4, kInt8 = 5, kUint8 = 6
};
3266 
3330 inline Symbol Embedding(const std::string& symbol_name,
3331  Symbol data,
3332  Symbol weight,
3333  int input_dim,
3334  int output_dim,
3336  bool sparse_grad = false) {
3337  static const char *EmbeddingDtypeValues[] = {
3338  "float16",
3339  "float32",
3340  "float64",
3341  "int32",
3342  "int64",
3343  "int8",
3344  "uint8"
3345  };
3346  return Operator("Embedding")
3347  .SetParam("input_dim", input_dim)
3348  .SetParam("output_dim", output_dim)
3349  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
3350  .SetParam("sparse_grad", sparse_grad)
3351  .SetInput("data", data)
3352  .SetInput("weight", weight)
3353  .CreateSymbol(symbol_name);
3354 }
3355 
/*! \brief Value for take's "mode" parameter; the integer values index take's table ("clip", "raise", "wrap"). */
enum class TakeMode { kClip = 0, kRaise = 1, kWrap = 2 };
3365 
3419 inline Symbol take(const std::string& symbol_name,
3420  Symbol a,
3421  Symbol indices,
3422  int axis = 0,
3423  TakeMode mode = TakeMode::kClip) {
3424  static const char *TakeModeValues[] = {
3425  "clip",
3426  "raise",
3427  "wrap"
3428  };
3429  return Operator("take")
3430  .SetParam("axis", axis)
3431  .SetParam("mode", TakeModeValues[int(mode)])
3432  .SetInput("a", a)
3433  .SetInput("indices", indices)
3434  .CreateSymbol(symbol_name);
3435 }
3436 
3465 inline Symbol batch_take(const std::string& symbol_name,
3466  Symbol a,
3467  Symbol indices) {
3468  return Operator("batch_take")
3469  .SetInput("a", a)
3470  .SetInput("indices", indices)
3471  .CreateSymbol(symbol_name);
3472 }
3473 
/*! \brief Value for one_hot's "dtype" parameter; the integer values index one_hot's dtype-name table. */
enum class One_hotDtype {
  kFloat16 = 0, kFloat32 = 1, kFloat64 = 2,
  kInt32 = 3, kInt64 = 4, kInt8 = 5, kUint8 = 6
};
3485 
3530 inline Symbol one_hot(const std::string& symbol_name,
3531  Symbol indices,
3532  int depth,
3533  double on_value = 1,
3534  double off_value = 0,
3536  static const char *One_hotDtypeValues[] = {
3537  "float16",
3538  "float32",
3539  "float64",
3540  "int32",
3541  "int64",
3542  "int8",
3543  "uint8"
3544  };
3545  return Operator("one_hot")
3546  .SetParam("depth", depth)
3547  .SetParam("on_value", on_value)
3548  .SetParam("off_value", off_value)
3549  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
3550  .SetInput("indices", indices)
3551  .CreateSymbol(symbol_name);
3552 }
3553 
3585 inline Symbol gather_nd(const std::string& symbol_name,
3586  Symbol data,
3587  Symbol indices) {
3588  return Operator("gather_nd")
3589  .SetInput("data", data)
3590  .SetInput("indices", indices)
3591  .CreateSymbol(symbol_name);
3592 }
3593 
3645 inline Symbol scatter_nd(const std::string& symbol_name,
3646  Symbol data,
3647  Symbol indices,
3648  Shape shape) {
3649  return Operator("scatter_nd")
3650  .SetParam("shape", shape)
3651  .SetInput("data", data)
3652  .SetInput("indices", indices)
3653  .CreateSymbol(symbol_name);
3654 }
3655 
3678 inline Symbol broadcast_equal(const std::string& symbol_name,
3679  Symbol lhs,
3680  Symbol rhs) {
3681  return Operator("broadcast_equal")
3682  .SetInput("lhs", lhs)
3683  .SetInput("rhs", rhs)
3684  .CreateSymbol(symbol_name);
3685 }
3686 
3709 inline Symbol broadcast_not_equal(const std::string& symbol_name,
3710  Symbol lhs,
3711  Symbol rhs) {
3712  return Operator("broadcast_not_equal")
3713  .SetInput("lhs", lhs)
3714  .SetInput("rhs", rhs)
3715  .CreateSymbol(symbol_name);
3716 }
3717 
3740 inline Symbol broadcast_greater(const std::string& symbol_name,
3741  Symbol lhs,
3742  Symbol rhs) {
3743  return Operator("broadcast_greater")
3744  .SetInput("lhs", lhs)
3745  .SetInput("rhs", rhs)
3746  .CreateSymbol(symbol_name);
3747 }
3748 
3771 inline Symbol broadcast_greater_equal(const std::string& symbol_name,
3772  Symbol lhs,
3773  Symbol rhs) {
3774  return Operator("broadcast_greater_equal")
3775  .SetInput("lhs", lhs)
3776  .SetInput("rhs", rhs)
3777  .CreateSymbol(symbol_name);
3778 }
3779 
3802 inline Symbol broadcast_lesser(const std::string& symbol_name,
3803  Symbol lhs,
3804  Symbol rhs) {
3805  return Operator("broadcast_lesser")
3806  .SetInput("lhs", lhs)
3807  .SetInput("rhs", rhs)
3808  .CreateSymbol(symbol_name);
3809 }
3810 
3833 inline Symbol broadcast_lesser_equal(const std::string& symbol_name,
3834  Symbol lhs,
3835  Symbol rhs) {
3836  return Operator("broadcast_lesser_equal")
3837  .SetInput("lhs", lhs)
3838  .SetInput("rhs", rhs)
3839  .CreateSymbol(symbol_name);
3840 }
3841 
3864 inline Symbol broadcast_logical_and(const std::string& symbol_name,
3865  Symbol lhs,
3866  Symbol rhs) {
3867  return Operator("broadcast_logical_and")
3868  .SetInput("lhs", lhs)
3869  .SetInput("rhs", rhs)
3870  .CreateSymbol(symbol_name);
3871 }
3872 
3895 inline Symbol broadcast_logical_or(const std::string& symbol_name,
3896  Symbol lhs,
3897  Symbol rhs) {
3898  return Operator("broadcast_logical_or")
3899  .SetInput("lhs", lhs)
3900  .SetInput("rhs", rhs)
3901  .CreateSymbol(symbol_name);
3902 }
3903 
3926 inline Symbol broadcast_logical_xor(const std::string& symbol_name,
3927  Symbol lhs,
3928  Symbol rhs) {
3929  return Operator("broadcast_logical_xor")
3930  .SetInput("lhs", lhs)
3931  .SetInput("rhs", rhs)
3932  .CreateSymbol(symbol_name);
3933 }
3934 
3978 inline Symbol diag(const std::string& symbol_name,
3979  Symbol data,
3980  dmlc::optional<int> k = dmlc::optional<int>(0)) {
3981  return Operator("diag")
3982  .SetParam("k", k)
3983  .SetInput("data", data)
3984  .CreateSymbol(symbol_name);
3985 }
3986 
4022 inline Symbol where(const std::string& symbol_name,
4023  Symbol condition,
4024  Symbol x,
4025  Symbol y) {
4026  return Operator("where")
4027  .SetInput("condition", condition)
4028  .SetInput("x", x)
4029  .SetInput("y", y)
4030  .CreateSymbol(symbol_name);
4031 }
4032 
4058 inline Symbol smooth_l1(const std::string& symbol_name,
4059  Symbol data,
4060  mx_float scalar) {
4061  return Operator("smooth_l1")
4062  .SetParam("scalar", scalar)
4063  .SetInput("data", data)
4064  .CreateSymbol(symbol_name);
4065 }
4066 
/*! \brief Value for cast_storage's "stype" parameter; the integer values index its table ("csr", "default", "row_sparse"). */
enum class Cast_storageStype { kCsr = 0, kDefault = 1, kRow_sparse = 2 };
4074 
4120 inline Symbol cast_storage(const std::string& symbol_name,
4121  Symbol data,
4122  Cast_storageStype stype) {
4123  static const char *Cast_storageStypeValues[] = {
4124  "csr",
4125  "default",
4126  "row_sparse"
4127  };
4128  return Operator("cast_storage")
4129  .SetParam("stype", Cast_storageStypeValues[int(stype)])
4130  .SetInput("data", data)
4131  .CreateSymbol(symbol_name);
4132 }
4133 
4155 inline Symbol sin(const std::string& symbol_name,
4156  Symbol data) {
4157  return Operator("sin")
4158  .SetInput("data", data)
4159  .CreateSymbol(symbol_name);
4160 }
4161 
4179 inline Symbol cos(const std::string& symbol_name,
4180  Symbol data) {
4181  return Operator("cos")
4182  .SetInput("data", data)
4183  .CreateSymbol(symbol_name);
4184 }
4185 
4207 inline Symbol tan(const std::string& symbol_name,
4208  Symbol data) {
4209  return Operator("tan")
4210  .SetInput("data", data)
4211  .CreateSymbol(symbol_name);
4212 }
4213 
4236 inline Symbol arcsin(const std::string& symbol_name,
4237  Symbol data) {
4238  return Operator("arcsin")
4239  .SetInput("data", data)
4240  .CreateSymbol(symbol_name);
4241 }
4242 
4261 inline Symbol arccos(const std::string& symbol_name,
4262  Symbol data) {
4263  return Operator("arccos")
4264  .SetInput("data", data)
4265  .CreateSymbol(symbol_name);
4266 }
4267 
4289 inline Symbol arctan(const std::string& symbol_name,
4290  Symbol data) {
4291  return Operator("arctan")
4292  .SetInput("data", data)
4293  .CreateSymbol(symbol_name);
4294 }
4295 
4315 inline Symbol degrees(const std::string& symbol_name,
4316  Symbol data) {
4317  return Operator("degrees")
4318  .SetInput("data", data)
4319  .CreateSymbol(symbol_name);
4320 }
4321 
4341 inline Symbol radians(const std::string& symbol_name,
4342  Symbol data) {
4343  return Operator("radians")
4344  .SetInput("data", data)
4345  .CreateSymbol(symbol_name);
4346 }
4347 
4367 inline Symbol sinh(const std::string& symbol_name,
4368  Symbol data) {
4369  return Operator("sinh")
4370  .SetInput("data", data)
4371  .CreateSymbol(symbol_name);
4372 }
4373 
4389 inline Symbol cosh(const std::string& symbol_name,
4390  Symbol data) {
4391  return Operator("cosh")
4392  .SetInput("data", data)
4393  .CreateSymbol(symbol_name);
4394 }
4395 
4415 inline Symbol tanh(const std::string& symbol_name,
4416  Symbol data) {
4417  return Operator("tanh")
4418  .SetInput("data", data)
4419  .CreateSymbol(symbol_name);
4420 }
4421 
4439 inline Symbol arcsinh(const std::string& symbol_name,
4440  Symbol data) {
4441  return Operator("arcsinh")
4442  .SetInput("data", data)
4443  .CreateSymbol(symbol_name);
4444 }
4445 
4459 inline Symbol arccosh(const std::string& symbol_name,
4460  Symbol data) {
4461  return Operator("arccosh")
4462  .SetInput("data", data)
4463  .CreateSymbol(symbol_name);
4464 }
4465 
4483 inline Symbol arctanh(const std::string& symbol_name,
4484  Symbol data) {
4485  return Operator("arctanh")
4486  .SetInput("data", data)
4487  .CreateSymbol(symbol_name);
4488 }
4489 
4492 enum class PoolingPoolType {
4493  kAvg = 0,
4494  kLp = 1,
4495  kMax = 2,
4496  kSum = 3
4497 };
4498 
4502  kFull = 0,
4503  kValid = 1
4504 };
4505 
4574 inline Symbol Pooling(const std::string& symbol_name,
4575  Symbol data,
4576  Shape kernel = Shape(),
4578  bool global_pool = false,
4579  bool cudnn_off = false,
4581  Shape stride = Shape(),
4582  Shape pad = Shape(),
4583  dmlc::optional<int> p_value = dmlc::optional<int>(),
4584  dmlc::optional<bool> count_include_pad = dmlc::optional<bool>()) {
4585  static const char *PoolingPoolTypeValues[] = {
4586  "avg",
4587  "lp",
4588  "max",
4589  "sum"
4590  };
4591  static const char *PoolingPoolingConventionValues[] = {
4592  "full",
4593  "valid"
4594  };
4595  return Operator("Pooling")
4596  .SetParam("kernel", kernel)
4597  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
4598  .SetParam("global_pool", global_pool)
4599  .SetParam("cudnn_off", cudnn_off)
4600  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
4601  .SetParam("stride", stride)
4602  .SetParam("pad", pad)
4603  .SetParam("p_value", p_value)
4604  .SetParam("count_include_pad", count_include_pad)
4605  .SetInput("data", data)
4606  .CreateSymbol(symbol_name);
4607 }
4608 
4641 inline Symbol softmax(const std::string& symbol_name,
4642  Symbol data,
4643  int axis = -1,
4644  dmlc::optional<double> temperature = dmlc::optional<double>()) {
4645  return Operator("softmax")
4646  .SetParam("axis", axis)
4647  .SetParam("temperature", temperature)
4648  .SetInput("data", data)
4649  .CreateSymbol(symbol_name);
4650 }
4651 
4675 inline Symbol log_softmax(const std::string& symbol_name,
4676  Symbol data,
4677  int axis = -1,
4678  dmlc::optional<double> temperature = dmlc::optional<double>()) {
4679  return Operator("log_softmax")
4680  .SetParam("axis", axis)
4681  .SetParam("temperature", temperature)
4682  .SetInput("data", data)
4683  .CreateSymbol(symbol_name);
4684 }
4685 
4689  kNone = 0,
4690  kFastest = 1,
4691  kLimited_workspace = 2,
4692  kOff = 3
4693 };
4694 
4698  kNone = 0,
4699  kNCDHW = 1,
4700  kNCHW = 2,
4701  kNCW = 3,
4702  kNDHWC = 4,
4703  kNHWC = 5
4704 };
4705 
4735 inline Symbol Deconvolution(const std::string& symbol_name,
4736  Symbol data,
4737  Symbol weight,
4738  Symbol bias,
4739  Shape kernel,
4740  uint32_t num_filter,
4741  Shape stride = Shape(),
4742  Shape dilate = Shape(),
4743  Shape pad = Shape(),
4744  Shape adj = Shape(),
4745  Shape target_shape = Shape(),
4746  uint32_t num_group = 1,
4747  uint64_t workspace = 512,
4748  bool no_bias = true,
4750  bool cudnn_off = false,
4752  static const char *DeconvolutionCudnnTuneValues[] = {
4753  "None",
4754  "fastest",
4755  "limited_workspace",
4756  "off"
4757  };
4758  static const char *DeconvolutionLayoutValues[] = {
4759  "None",
4760  "NCDHW",
4761  "NCHW",
4762  "NCW",
4763  "NDHWC",
4764  "NHWC"
4765  };
4766  return Operator("Deconvolution")
4767  .SetParam("kernel", kernel)
4768  .SetParam("num_filter", num_filter)
4769  .SetParam("stride", stride)
4770  .SetParam("dilate", dilate)
4771  .SetParam("pad", pad)
4772  .SetParam("adj", adj)
4773  .SetParam("target_shape", target_shape)
4774  .SetParam("num_group", num_group)
4775  .SetParam("workspace", workspace)
4776  .SetParam("no_bias", no_bias)
4777  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
4778  .SetParam("cudnn_off", cudnn_off)
4779  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
4780  .SetInput("data", data)
4781  .SetInput("weight", weight)
4782  .SetInput("bias", bias)
4783  .CreateSymbol(symbol_name);
4784 }
4785 
/*!
 * \brief Values for the "act_type" parameter of Activation().
 *
 * Activation() uses each enumerator's integer value as an index into its
 * name table, so the order and values here must stay in sync with it.
 */
enum class ActivationActType : int {
  kRelu = 0,      // "relu"
  kSigmoid = 1,   // "sigmoid"
  kSoftrelu = 2,  // "softrelu"
  kSoftsign = 3,  // "softsign"
  kTanh = 4,      // "tanh"
};
4795 
4815 inline Symbol Activation(const std::string& symbol_name,
4816  Symbol data,
4817  ActivationActType act_type) {
4818  static const char *ActivationActTypeValues[] = {
4819  "relu",
4820  "sigmoid",
4821  "softrelu",
4822  "softsign",
4823  "tanh"
4824  };
4825  return Operator("Activation")
4826  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
4827  .SetInput("data", data)
4828  .CreateSymbol(symbol_name);
4829 }
4830 
4900 inline Symbol BatchNorm(const std::string& symbol_name,
4901  Symbol data,
4902  Symbol gamma,
4903  Symbol beta,
4904  Symbol moving_mean,
4905  Symbol moving_var,
4906  double eps = 0.001,
4907  mx_float momentum = 0.9,
4908  bool fix_gamma = true,
4909  bool use_global_stats = false,
4910  bool output_mean_var = false,
4911  int axis = 1,
4912  bool cudnn_off = false) {
4913  return Operator("BatchNorm")
4914  .SetParam("eps", eps)
4915  .SetParam("momentum", momentum)
4916  .SetParam("fix_gamma", fix_gamma)
4917  .SetParam("use_global_stats", use_global_stats)
4918  .SetParam("output_mean_var", output_mean_var)
4919  .SetParam("axis", axis)
4920  .SetParam("cudnn_off", cudnn_off)
4921  .SetInput("data", data)
4922  .SetInput("gamma", gamma)
4923  .SetInput("beta", beta)
4924  .SetInput("moving_mean", moving_mean)
4925  .SetInput("moving_var", moving_var)
4926  .CreateSymbol(symbol_name);
4927 }
4928 
4932  kNone = 0,
4933  kFastest = 1,
4934  kLimited_workspace = 2,
4935  kOff = 3
4936 };
4937 
4941 enum class ConvolutionLayout {
4942  kNone = 0,
4943  kNCDHW = 1,
4944  kNCHW = 2,
4945  kNCW = 3,
4946  kNDHWC = 4,
4947  kNHWC = 5
4948 };
4949 
5047 inline Symbol Convolution(const std::string& symbol_name,
5048  Symbol data,
5049  Symbol weight,
5050  Symbol bias,
5051  Shape kernel,
5052  uint32_t num_filter,
5053  Shape stride = Shape(),
5054  Shape dilate = Shape(),
5055  Shape pad = Shape(),
5056  uint32_t num_group = 1,
5057  uint64_t workspace = 1024,
5058  bool no_bias = false,
5060  bool cudnn_off = false,
5062  static const char *ConvolutionCudnnTuneValues[] = {
5063  "None",
5064  "fastest",
5065  "limited_workspace",
5066  "off"
5067  };
5068  static const char *ConvolutionLayoutValues[] = {
5069  "None",
5070  "NCDHW",
5071  "NCHW",
5072  "NCW",
5073  "NDHWC",
5074  "NHWC"
5075  };
5076  return Operator("Convolution")
5077  .SetParam("kernel", kernel)
5078  .SetParam("num_filter", num_filter)
5079  .SetParam("stride", stride)
5080  .SetParam("dilate", dilate)
5081  .SetParam("pad", pad)
5082  .SetParam("num_group", num_group)
5083  .SetParam("workspace", workspace)
5084  .SetParam("no_bias", no_bias)
5085  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
5086  .SetParam("cudnn_off", cudnn_off)
5087  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
5088  .SetInput("data", data)
5089  .SetInput("weight", weight)
5090  .SetInput("bias", bias)
5091  .CreateSymbol(symbol_name);
5092 }
5093 
5097  kBilinear = 0,
5098  kNearest = 1
5099 };
5100 
5105  kConcat = 0,
5106  kSum = 1
5107 };
5108 
5124 inline Symbol UpSampling(const std::string& symbol_name,
5125  const std::vector<Symbol>& data,
5126  uint32_t scale,
5127  UpSamplingSampleType sample_type,
5128  int num_args,
5129  uint32_t num_filter = 0,
5131  uint64_t workspace = 512) {
5132  static const char *UpSamplingSampleTypeValues[] = {
5133  "bilinear",
5134  "nearest"
5135  };
5136  static const char *UpSamplingMultiInputModeValues[] = {
5137  "concat",
5138  "sum"
5139  };
5140  return Operator("UpSampling")
5141  .SetParam("scale", scale)
5142  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
5143  .SetParam("num_args", num_args)
5144  .SetParam("num_filter", num_filter)
5145  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
5146  .SetParam("workspace", workspace)
5147 (data)
5148  .CreateSymbol(symbol_name);
5149 }
5150 
5197 inline Symbol Concat(const std::string& symbol_name,
5198  const std::vector<Symbol>& data,
5199  int num_args,
5200  int dim = 1) {
5201  return Operator("Concat")
5202  .SetParam("num_args", num_args)
5203  .SetParam("dim", dim)
5204 (data)
5205  .CreateSymbol(symbol_name);
5206 }
5207 
5246 inline Symbol LayerNorm(const std::string& symbol_name,
5247  Symbol data,
5248  Symbol gamma,
5249  Symbol beta,
5250  int axis = -1,
5251  mx_float eps = 1e-05,
5252  bool output_mean_var = false) {
5253  return Operator("LayerNorm")
5254  .SetParam("axis", axis)
5255  .SetParam("eps", eps)
5256  .SetParam("output_mean_var", output_mean_var)
5257  .SetInput("data", data)
5258  .SetInput("gamma", gamma)
5259  .SetInput("beta", beta)
5260  .CreateSymbol(symbol_name);
5261 }
5262 
5290 inline Symbol LRN(const std::string& symbol_name,
5291  Symbol data,
5292  uint32_t nsize,
5293  mx_float alpha = 0.0001,
5294  mx_float beta = 0.75,
5295  mx_float knorm = 2) {
5296  return Operator("LRN")
5297  .SetParam("nsize", nsize)
5298  .SetParam("alpha", alpha)
5299  .SetParam("beta", beta)
5300  .SetParam("knorm", knorm)
5301  .SetInput("data", data)
5302  .CreateSymbol(symbol_name);
5303 }
5304 
/*!
 * \brief Values for the "mode" parameter of Dropout().
 *
 * Dropout() uses each enumerator's integer value as an index into its name
 * table, so the order and values here must stay in sync with it.
 */
enum class DropoutMode : int {
  kAlways = 0,    // "always"
  kTraining = 1,  // "training"
};
5311 
5352 inline Symbol Dropout(const std::string& symbol_name,
5353  Symbol data,
5354  mx_float p = 0.5,
5356  Shape axes = Shape()) {
5357  static const char *DropoutModeValues[] = {
5358  "always",
5359  "training"
5360  };
5361  return Operator("Dropout")
5362  .SetParam("p", p)
5363  .SetParam("mode", DropoutModeValues[int(mode)])
5364  .SetParam("axes", axes)
5365  .SetInput("data", data)
5366  .CreateSymbol(symbol_name);
5367 }
5368 
5373  kChannel = 0,
5374  kInstance = 1
5375 };
5376 
5410 inline Symbol SoftmaxActivation(const std::string& symbol_name,
5411  Symbol data,
5413  static const char *SoftmaxActivationModeValues[] = {
5414  "channel",
5415  "instance"
5416  };
5417  return Operator("SoftmaxActivation")
5418  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
5419  .SetInput("data", data)
5420  .CreateSymbol(symbol_name);
5421 }
5422 
5466 inline Symbol FullyConnected(const std::string& symbol_name,
5467  Symbol data,
5468  Symbol weight,
5469  Symbol bias,
5470  int num_hidden,
5471  bool no_bias = false,
5472  bool flatten = true) {
5473  return Operator("FullyConnected")
5474  .SetParam("num_hidden", num_hidden)
5475  .SetParam("no_bias", no_bias)
5476  .SetParam("flatten", flatten)
5477  .SetInput("data", data)
5478  .SetInput("weight", weight)
5479  .SetInput("bias", bias)
5480  .CreateSymbol(symbol_name);
5481 }
5482 
/*!
 * \brief Values for the "mode" parameter of Pad().
 *
 * Pad() uses each enumerator's integer value as an index into its name
 * table, so the order and values here must stay in sync with it.
 */
enum class PadMode : int {
  kConstant = 0,  // "constant"
  kEdge = 1,      // "edge"
  kReflect = 2,   // "reflect"
};
5491 
5588 inline Symbol Pad(const std::string& symbol_name,
5589  Symbol data,
5590  PadMode mode,
5591  Shape pad_width,
5592  double constant_value = 0) {
5593  static const char *PadModeValues[] = {
5594  "constant",
5595  "edge",
5596  "reflect"
5597  };
5598  return Operator("Pad")
5599  .SetParam("mode", PadModeValues[int(mode)])
5600  .SetParam("pad_width", pad_width)
5601  .SetParam("constant_value", constant_value)
5602  .SetInput("data", data)
5603  .CreateSymbol(symbol_name);
5604 }
5605 
/*!
 * \brief Values for the "act_type" parameter of LeakyReLU().
 *
 * LeakyReLU() uses each enumerator's integer value as an index into its name
 * table, so the order and values here must stay in sync with it.
 */
enum class LeakyReLUActType : int {
  kElu = 0,    // "elu"
  kLeaky = 1,  // "leaky"
  kPrelu = 2,  // "prelu"
  kRrelu = 3,  // "rrelu"
  kSelu = 4,   // "selu"
};
5615 
5646 inline Symbol LeakyReLU(const std::string& symbol_name,
5647  Symbol data,
5648  Symbol gamma,
5650  mx_float slope = 0.25,
5651  mx_float lower_bound = 0.125,
5652  mx_float upper_bound = 0.334) {
5653  static const char *LeakyReLUActTypeValues[] = {
5654  "elu",
5655  "leaky",
5656  "prelu",
5657  "rrelu",
5658  "selu"
5659  };
5660  return Operator("LeakyReLU")
5661  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
5662  .SetParam("slope", slope)
5663  .SetParam("lower_bound", lower_bound)
5664  .SetParam("upper_bound", upper_bound)
5665  .SetInput("data", data)
5666  .SetInput("gamma", gamma)
5667  .CreateSymbol(symbol_name);
5668 }
5669 
5698 inline Symbol SwapAxis(const std::string& symbol_name,
5699  Symbol data,
5700  uint32_t dim1 = 0,
5701  uint32_t dim2 = 0) {
5702  return Operator("SwapAxis")
5703  .SetParam("dim1", dim1)
5704  .SetParam("dim2", dim2)
5705  .SetInput("data", data)
5706  .CreateSymbol(symbol_name);
5707 }
5708 
5769 inline Symbol BatchNorm_v1(const std::string& symbol_name,
5770  Symbol data,
5771  Symbol gamma,
5772  Symbol beta,
5773  mx_float eps = 0.001,
5774  mx_float momentum = 0.9,
5775  bool fix_gamma = true,
5776  bool use_global_stats = false,
5777  bool output_mean_var = false) {
5778  return Operator("BatchNorm_v1")
5779  .SetParam("eps", eps)
5780  .SetParam("momentum", momentum)
5781  .SetParam("fix_gamma", fix_gamma)
5782  .SetParam("use_global_stats", use_global_stats)
5783  .SetParam("output_mean_var", output_mean_var)
5784  .SetInput("data", data)
5785  .SetInput("gamma", gamma)
5786  .SetInput("beta", beta)
5787  .CreateSymbol(symbol_name);
5788 }
5789 
5827 inline Symbol softmax_cross_entropy(const std::string& symbol_name,
5828  Symbol data,
5829  Symbol label) {
5830  return Operator("softmax_cross_entropy")
5831  .SetInput("data", data)
5832  .SetInput("label", label)
5833  .CreateSymbol(symbol_name);
5834 }
5835 
5865 inline Symbol LinearRegressionOutput(const std::string& symbol_name,
5866  Symbol data,
5867  Symbol label,
5868  mx_float grad_scale = 1) {
5869  return Operator("LinearRegressionOutput")
5870  .SetParam("grad_scale", grad_scale)
5871  .SetInput("data", data)
5872  .SetInput("label", label)
5873  .CreateSymbol(symbol_name);
5874 }
5875 
5906 inline Symbol MAERegressionOutput(const std::string& symbol_name,
5907  Symbol data,
5908  Symbol label,
5909  mx_float grad_scale = 1) {
5910  return Operator("MAERegressionOutput")
5911  .SetParam("grad_scale", grad_scale)
5912  .SetInput("data", data)
5913  .SetInput("label", label)
5914  .CreateSymbol(symbol_name);
5915 }
5916 
5953 inline Symbol LogisticRegressionOutput(const std::string& symbol_name,
5954  Symbol data,
5955  Symbol label,
5956  mx_float grad_scale = 1) {
5957  return Operator("LogisticRegressionOutput")
5958  .SetParam("grad_scale", grad_scale)
5959  .SetInput("data", data)
5960  .SetInput("label", label)
5961  .CreateSymbol(symbol_name);
5962 }
5963 
5973 inline Symbol IdentityAttachKLSparseReg(const std::string& symbol_name,
5974  Symbol data,
5975  mx_float sparseness_target = 0.1,
5976  mx_float penalty = 0.001,
5977  mx_float momentum = 0.9) {
5978  return Operator("IdentityAttachKLSparseReg")
5979  .SetParam("sparseness_target", sparseness_target)
5980  .SetParam("penalty", penalty)
5981  .SetParam("momentum", momentum)
5982  .SetInput("data", data)
5983  .CreateSymbol(symbol_name);
5984 }
5985 
6014 inline Symbol signsgd_update(const std::string& symbol_name,
6015  Symbol weight,
6016  Symbol grad,
6017  mx_float lr,
6018  mx_float wd = 0,
6019  mx_float rescale_grad = 1,
6020  mx_float clip_gradient = -1) {
6021  return Operator("signsgd_update")
6022  .SetParam("lr", lr)
6023  .SetParam("wd", wd)
6024  .SetParam("rescale_grad", rescale_grad)
6025  .SetParam("clip_gradient", clip_gradient)
6026  .SetInput("weight", weight)
6027  .SetInput("grad", grad)
6028  .CreateSymbol(symbol_name);
6029 }
6030 
6065 inline Symbol signum_update(const std::string& symbol_name,
6066  Symbol weight,
6067  Symbol grad,
6068  Symbol mom,
6069  mx_float lr,
6070  mx_float momentum = 0,
6071  mx_float wd = 0,
6072  mx_float rescale_grad = 1,
6073  mx_float clip_gradient = -1,
6074  mx_float wd_lh = 0) {
6075  return Operator("signum_update")
6076  .SetParam("lr", lr)
6077  .SetParam("momentum", momentum)
6078  .SetParam("wd", wd)
6079  .SetParam("rescale_grad", rescale_grad)
6080  .SetParam("clip_gradient", clip_gradient)
6081  .SetParam("wd_lh", wd_lh)
6082  .SetInput("weight", weight)
6083  .SetInput("grad", grad)
6084  .SetInput("mom", mom)
6085  .CreateSymbol(symbol_name);
6086 }
6087 
6116 inline Symbol sgd_update(const std::string& symbol_name,
6117  Symbol weight,
6118  Symbol grad,
6119  mx_float lr,
6120  mx_float wd = 0,
6121  mx_float rescale_grad = 1,
6122  mx_float clip_gradient = -1,
6123  bool lazy_update = true) {
6124  return Operator("sgd_update")
6125  .SetParam("lr", lr)
6126  .SetParam("wd", wd)
6127  .SetParam("rescale_grad", rescale_grad)
6128  .SetParam("clip_gradient", clip_gradient)
6129  .SetParam("lazy_update", lazy_update)
6130  .SetInput("weight", weight)
6131  .SetInput("grad", grad)
6132  .CreateSymbol(symbol_name);
6133 }
6134 
6179 inline Symbol sgd_mom_update(const std::string& symbol_name,
6180  Symbol weight,
6181  Symbol grad,
6182  Symbol mom,
6183  mx_float lr,
6184  mx_float momentum = 0,
6185  mx_float wd = 0,
6186  mx_float rescale_grad = 1,
6187  mx_float clip_gradient = -1,
6188  bool lazy_update = true) {
6189  return Operator("sgd_mom_update")
6190  .SetParam("lr", lr)
6191  .SetParam("momentum", momentum)
6192  .SetParam("wd", wd)
6193  .SetParam("rescale_grad", rescale_grad)
6194  .SetParam("clip_gradient", clip_gradient)
6195  .SetParam("lazy_update", lazy_update)
6196  .SetInput("weight", weight)
6197  .SetInput("grad", grad)
6198  .SetInput("mom", mom)
6199  .CreateSymbol(symbol_name);
6200 }
6201 
6217 inline Symbol mp_sgd_update(const std::string& symbol_name,
6218  Symbol weight,
6219  Symbol grad,
6220  Symbol weight32,
6221  mx_float lr,
6222  mx_float wd = 0,
6223  mx_float rescale_grad = 1,
6224  mx_float clip_gradient = -1,
6225  bool lazy_update = true) {
6226  return Operator("mp_sgd_update")
6227  .SetParam("lr", lr)
6228  .SetParam("wd", wd)
6229  .SetParam("rescale_grad", rescale_grad)
6230  .SetParam("clip_gradient", clip_gradient)
6231  .SetParam("lazy_update", lazy_update)
6232  .SetInput("weight", weight)
6233  .SetInput("grad", grad)
6234  .SetInput("weight32", weight32)
6235  .CreateSymbol(symbol_name);
6236 }
6237 
6255 inline Symbol mp_sgd_mom_update(const std::string& symbol_name,
6256  Symbol weight,
6257  Symbol grad,
6258  Symbol mom,
6259  Symbol weight32,
6260  mx_float lr,
6261  mx_float momentum = 0,
6262  mx_float wd = 0,
6263  mx_float rescale_grad = 1,
6264  mx_float clip_gradient = -1,
6265  bool lazy_update = true) {
6266  return Operator("mp_sgd_mom_update")
6267  .SetParam("lr", lr)
6268  .SetParam("momentum", momentum)
6269  .SetParam("wd", wd)
6270  .SetParam("rescale_grad", rescale_grad)
6271  .SetParam("clip_gradient", clip_gradient)
6272  .SetParam("lazy_update", lazy_update)
6273  .SetInput("weight", weight)
6274  .SetInput("grad", grad)
6275  .SetInput("mom", mom)
6276  .SetInput("weight32", weight32)
6277  .CreateSymbol(symbol_name);
6278 }
6279 
6315 inline Symbol ftml_update(const std::string& symbol_name,
6316  Symbol weight,
6317  Symbol grad,
6318  Symbol d,
6319  Symbol v,
6320  Symbol z,
6321  mx_float lr,
6322  int t,
6323  mx_float beta1 = 0.6,
6324  mx_float beta2 = 0.999,
6325  double epsilon = 1e-08,
6326  mx_float wd = 0,
6327  mx_float rescale_grad = 1,
6328  mx_float clip_grad = -1) {
6329  return Operator("ftml_update")
6330  .SetParam("lr", lr)
6331  .SetParam("t", t)
6332  .SetParam("beta1", beta1)
6333  .SetParam("beta2", beta2)
6334  .SetParam("epsilon", epsilon)
6335  .SetParam("wd", wd)
6336  .SetParam("rescale_grad", rescale_grad)
6337  .SetParam("clip_grad", clip_grad)
6338  .SetInput("weight", weight)
6339  .SetInput("grad", grad)
6340  .SetInput("d", d)
6341  .SetInput("v", v)
6342  .SetInput("z", z)
6343  .CreateSymbol(symbol_name);
6344 }
6345 
6395 inline Symbol adam_update(const std::string& symbol_name,
6396  Symbol weight,
6397  Symbol grad,
6398  Symbol mean,
6399  Symbol var,
6400  mx_float lr,
6401  mx_float beta1 = 0.9,
6402  mx_float beta2 = 0.999,
6403  mx_float epsilon = 1e-08,
6404  mx_float wd = 0,
6405  mx_float rescale_grad = 1,
6406  mx_float clip_gradient = -1,
6407  bool lazy_update = true) {
6408  return Operator("adam_update")
6409  .SetParam("lr", lr)
6410  .SetParam("beta1", beta1)
6411  .SetParam("beta2", beta2)
6412  .SetParam("epsilon", epsilon)
6413  .SetParam("wd", wd)
6414  .SetParam("rescale_grad", rescale_grad)
6415  .SetParam("clip_gradient", clip_gradient)
6416  .SetParam("lazy_update", lazy_update)
6417  .SetInput("weight", weight)
6418  .SetInput("grad", grad)
6419  .SetInput("mean", mean)
6420  .SetInput("var", var)
6421  .CreateSymbol(symbol_name);
6422 }
6423 
6477 inline Symbol rmsprop_update(const std::string& symbol_name,
6478  Symbol weight,
6479  Symbol grad,
6480  Symbol n,
6481  mx_float lr,
6482  mx_float gamma1 = 0.95,
6483  mx_float epsilon = 1e-08,
6484  mx_float wd = 0,
6485  mx_float rescale_grad = 1,
6486  mx_float clip_gradient = -1,
6487  mx_float clip_weights = -1) {
6488  return Operator("rmsprop_update")
6489  .SetParam("lr", lr)
6490  .SetParam("gamma1", gamma1)
6491  .SetParam("epsilon", epsilon)
6492  .SetParam("wd", wd)
6493  .SetParam("rescale_grad", rescale_grad)
6494  .SetParam("clip_gradient", clip_gradient)
6495  .SetParam("clip_weights", clip_weights)
6496  .SetInput("weight", weight)
6497  .SetInput("grad", grad)
6498  .SetInput("n", n)
6499  .CreateSymbol(symbol_name);
6500 }
6501 
6547 inline Symbol rmspropalex_update(const std::string& symbol_name,
6548  Symbol weight,
6549  Symbol grad,
6550  Symbol n,
6551  Symbol g,
6552  Symbol delta,
6553  mx_float lr,
6554  mx_float gamma1 = 0.95,
6555  mx_float gamma2 = 0.9,
6556  mx_float epsilon = 1e-08,
6557  mx_float wd = 0,
6558  mx_float rescale_grad = 1,
6559  mx_float clip_gradient = -1,
6560  mx_float clip_weights = -1) {
6561  return Operator("rmspropalex_update")
6562  .SetParam("lr", lr)
6563  .SetParam("gamma1", gamma1)
6564  .SetParam("gamma2", gamma2)
6565  .SetParam("epsilon", epsilon)
6566  .SetParam("wd", wd)
6567  .SetParam("rescale_grad", rescale_grad)
6568  .SetParam("clip_gradient", clip_gradient)
6569  .SetParam("clip_weights", clip_weights)
6570  .SetInput("weight", weight)
6571  .SetInput("grad", grad)
6572  .SetInput("n", n)
6573  .SetInput("g", g)
6574  .SetInput("delta", delta)
6575  .CreateSymbol(symbol_name);
6576 }
6577 
6617 inline Symbol ftrl_update(const std::string& symbol_name,
6618  Symbol weight,
6619  Symbol grad,
6620  Symbol z,
6621  Symbol n,
6622  mx_float lr,
6623  mx_float lamda1 = 0.01,
6624  mx_float beta = 1,
6625  mx_float wd = 0,
6626  mx_float rescale_grad = 1,
6627  mx_float clip_gradient = -1) {
6628  return Operator("ftrl_update")
6629  .SetParam("lr", lr)
6630  .SetParam("lamda1", lamda1)
6631  .SetParam("beta", beta)
6632  .SetParam("wd", wd)
6633  .SetParam("rescale_grad", rescale_grad)
6634  .SetParam("clip_gradient", clip_gradient)
6635  .SetInput("weight", weight)
6636  .SetInput("grad", grad)
6637  .SetInput("z", z)
6638  .SetInput("n", n)
6639  .CreateSymbol(symbol_name);
6640 }
6641 
6713 inline Symbol SliceChannel(const std::string& symbol_name,
6714  Symbol data,
6715  int num_outputs,
6716  int axis = 1,
6717  bool squeeze_axis = false) {
6718  return Operator("SliceChannel")
6719  .SetParam("num_outputs", num_outputs)
6720  .SetParam("axis", axis)
6721  .SetParam("squeeze_axis", squeeze_axis)
6722  .SetInput("data", data)
6723  .CreateSymbol(symbol_name);
6724 }
6725 
6776 inline Symbol InstanceNorm(const std::string& symbol_name,
6777  Symbol data,
6778  Symbol gamma,
6779  Symbol beta,
6780  mx_float eps = 0.001) {
6781  return Operator("InstanceNorm")
6782  .SetParam("eps", eps)
6783  .SetInput("data", data)
6784  .SetInput("gamma", gamma)
6785  .SetInput("beta", beta)
6786  .CreateSymbol(symbol_name);
6787 }
6788 
/*!
 * \brief Values for the "transform_type" parameter of GridGenerator().
 *
 * Each enumerator's integer value indexes the name table in GridGenerator()
 * ("affine", "warp"), so the order must match that table.
 */
enum class GridGeneratorTransformType {
  kAffine = 0,
  kWarp = 1
};
6796 
6807 inline Symbol GridGenerator(const std::string& symbol_name,
6808  Symbol data,
6809  GridGeneratorTransformType transform_type,
6810  Shape target_shape = Shape(0,0)) {
6811  static const char *GridGeneratorTransformTypeValues[] = {
6812  "affine",
6813  "warp"
6814  };
6815  return Operator("GridGenerator")
6816  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
6817  .SetParam("target_shape", target_shape)
6818  .SetInput("data", data)
6819  .CreateSymbol(symbol_name);
6820 }
6821 
6825  kAvg = 0,
6826  kMax = 1,
6827  kSum = 2
6828 };
6829 
6833  kFull = 0,
6834  kValid = 1
6835 };
6836 
6888 inline Symbol Pooling_v1(const std::string& symbol_name,
6889  Symbol data,
6890  Shape kernel = Shape(),
6892  bool global_pool = false,
6894  Shape stride = Shape(),
6895  Shape pad = Shape()) {
6896  static const char *Pooling_v1PoolTypeValues[] = {
6897  "avg",
6898  "max",
6899  "sum"
6900  };
6901  static const char *Pooling_v1PoolingConventionValues[] = {
6902  "full",
6903  "valid"
6904  };
6905  return Operator("Pooling_v1")
6906  .SetParam("kernel", kernel)
6907  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
6908  .SetParam("global_pool", global_pool)
6909  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
6910  .SetParam("stride", stride)
6911  .SetParam("pad", pad)
6912  .SetInput("data", data)
6913  .CreateSymbol(symbol_name);
6914 }
6915 
/*!
 * \brief Values for the "mode" parameter of RNN().
 *
 * RNN() uses each enumerator's integer value as an index into its name
 * table, so the order and values here must stay in sync with it.
 */
enum class RNNMode : int {
  kGru = 0,       // "gru"
  kLstm = 1,      // "lstm"
  kRnn_relu = 2,  // "rnn_relu"
  kRnn_tanh = 3,  // "rnn_tanh"
};
6924 
6991 inline Symbol RNN(const std::string& symbol_name,
6992  Symbol data,
6993  Symbol parameters,
6994  Symbol state,
6995  Symbol state_cell,
6996  uint32_t state_size,
6997  uint32_t num_layers,
6998  RNNMode mode,
6999  bool bidirectional = false,
7000  mx_float p = 0,
7001  bool state_outputs = false,
7002  dmlc::optional<int> projection_size = dmlc::optional<int>(),
7003  dmlc::optional<double> lstm_state_clip_min = dmlc::optional<double>(),
7004  dmlc::optional<double> lstm_state_clip_max = dmlc::optional<double>(),
7005  bool lstm_state_clip_nan = false) {
7006  static const char *RNNModeValues[] = {
7007  "gru",
7008  "lstm",
7009  "rnn_relu",
7010  "rnn_tanh"
7011  };
7012  return Operator("RNN")
7013  .SetParam("state_size", state_size)
7014  .SetParam("num_layers", num_layers)
7015  .SetParam("mode", RNNModeValues[int(mode)])
7016  .SetParam("bidirectional", bidirectional)
7017  .SetParam("p", p)
7018  .SetParam("state_outputs", state_outputs)
7019  .SetParam("projection_size", projection_size)
7020  .SetParam("lstm_state_clip_min", lstm_state_clip_min)
7021  .SetParam("lstm_state_clip_max", lstm_state_clip_max)
7022  .SetParam("lstm_state_clip_nan", lstm_state_clip_nan)
7023  .SetInput("data", data)
7024  .SetInput("parameters", parameters)
7025  .SetInput("state", state)
7026  .SetInput("state_cell", state_cell)
7027  .CreateSymbol(symbol_name);
7028 }
7029 
7040  kNone = 0,
7041  kFastest = 1,
7042  kLimited_workspace = 2,
7043  kOff = 3
7044 };
7045 
7050  kNone = 0,
7051  kNCDHW = 1,
7052  kNCHW = 2,
7053  kNDHWC = 3,
7054  kNHWC = 4
7055 };
7056 
7087 inline Symbol Convolution_v1(const std::string& symbol_name,
7088  Symbol data,
7089  Symbol weight,
7090  Symbol bias,
7091  Shape kernel,
7092  uint32_t num_filter,
7093  Shape stride = Shape(),
7094  Shape dilate = Shape(),
7095  Shape pad = Shape(),
7096  uint32_t num_group = 1,
7097  uint64_t workspace = 1024,
7098  bool no_bias = false,
7100  bool cudnn_off = false,
7102  static const char *Convolution_v1CudnnTuneValues[] = {
7103  "None",
7104  "fastest",
7105  "limited_workspace",
7106  "off"
7107  };
7108  static const char *Convolution_v1LayoutValues[] = {
7109  "None",
7110  "NCDHW",
7111  "NCHW",
7112  "NDHWC",
7113  "NHWC"
7114  };
7115  return Operator("Convolution_v1")
7116  .SetParam("kernel", kernel)
7117  .SetParam("num_filter", num_filter)
7118  .SetParam("stride", stride)
7119  .SetParam("dilate", dilate)
7120  .SetParam("pad", pad)
7121  .SetParam("num_group", num_group)
7122  .SetParam("workspace", workspace)
7123  .SetParam("no_bias", no_bias)
7124  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
7125  .SetParam("cudnn_off", cudnn_off)
7126  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
7127  .SetInput("data", data)
7128  .SetInput("weight", weight)
7129  .SetInput("bias", bias)
7130  .CreateSymbol(symbol_name);
7131 }
7132 
7153 inline Symbol Crop(const std::string& symbol_name,
7154  const std::vector<Symbol>& data,
7155  int num_args,
7156  Shape offset = Shape(0,0),
7157  Shape h_w = Shape(0,0),
7158  bool center_crop = false) {
7159  return Operator("Crop")
7160  .SetParam("num_args", num_args)
7161  .SetParam("offset", offset)
7162  .SetParam("h_w", h_w)
7163  .SetParam("center_crop", center_crop)
7164 (data)
7165  .CreateSymbol(symbol_name);
7166 }
7167 
7244 inline Symbol SequenceReverse(const std::string& symbol_name,
7245  Symbol data,
7246  Symbol sequence_length,
7247  bool use_sequence_length = false,
7248  int axis = 0) {
7249  return Operator("SequenceReverse")
7250  .SetParam("use_sequence_length", use_sequence_length)
7251  .SetParam("axis", axis)
7252  .SetInput("data", data)
7253  .SetInput("sequence_length", sequence_length)
7254  .CreateSymbol(symbol_name);
7255 }
7256 
/*!
 * \brief Values for the "transform_type" parameter of SpatialTransformer().
 *
 * The enumerator's integer value indexes the name table in SpatialTransformer() ("affine").
 */
enum class SpatialTransformerTransformType {
  kAffine = 0
};

/*!
 * \brief Values for the "sampler_type" parameter of SpatialTransformer().
 *
 * The enumerator's integer value indexes the name table in SpatialTransformer() ("bilinear").
 */
enum class SpatialTransformerSamplerType {
  kBilinear = 0
};
7268 
7279 inline Symbol SpatialTransformer(const std::string& symbol_name,
7280  Symbol data,
7281  Symbol loc,
7282  SpatialTransformerTransformType transform_type,
7283  SpatialTransformerSamplerType sampler_type,
7284  Shape target_shape = Shape(0,0)) {
7285  static const char *SpatialTransformerTransformTypeValues[] = {
7286  "affine"
7287  };
7288  static const char *SpatialTransformerSamplerTypeValues[] = {
7289  "bilinear"
7290  };
7291  return Operator("SpatialTransformer")
7292  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
7293  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
7294  .SetParam("target_shape", target_shape)
7295  .SetInput("data", data)
7296  .SetInput("loc", loc)
7297  .CreateSymbol(symbol_name);
7298 }
7299 
7303  kBatch = 0,
7304  kNull = 1,
7305  kValid = 2
7306 };
7307 
/*!
 * \brief Builds a `SoftmaxOutput` symbol from `data` and `label`, forwarding
 *        each option below as a named operator parameter.
 * NOTE(review): this listing skips source line 7411, which presumably declares
 *        the `normalization` parameter read further down; as extracted,
 *        `normalization` is undeclared. Confirm against the generated header.
 */
7403 inline Symbol SoftmaxOutput(const std::string& symbol_name,
7404  Symbol data,
7405  Symbol label,
7406  mx_float grad_scale = 1,
7407  mx_float ignore_label = -1,
7408  bool multi_output = false,
7409  bool use_ignore = false,
7410  bool preserve_shape = false,
7412  bool out_grad = false,
7413  mx_float smooth_alpha = 0) {
// Names indexed by int(normalization): the order must match the enumerator
// values (presumably the kBatch=0 / kNull=1 / kValid=2 enum listed just above
// this function — verify).
7414  static const char *SoftmaxOutputNormalizationValues[] = {
7415  "batch",
7416  "null",
7417  "valid"
7418  };
7419  return Operator("SoftmaxOutput")
7420  .SetParam("grad_scale", grad_scale)
7421  .SetParam("ignore_label", ignore_label)
7422  .SetParam("multi_output", multi_output)
7423  .SetParam("use_ignore", use_ignore)
7424  .SetParam("preserve_shape", preserve_shape)
7425  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
7426  .SetParam("out_grad", out_grad)
7427  .SetParam("smooth_alpha", smooth_alpha)
7428  .SetInput("data", data)
7429  .SetInput("label", label)
7430  .CreateSymbol(symbol_name);
7431 }
7432 
7436  kBatch = 0,
7437  kNull = 1,
7438  kValid = 2
7439 };
7440 
/*!
 * \brief Builds a `Softmax` symbol from `data`, forwarding each option below
 *        as a named operator parameter.
 * NOTE(review): this listing skips source line 7475, which presumably declares
 *        the `normalization` parameter read further down; as extracted,
 *        `normalization` is undeclared. Confirm against the generated header.
 */
7468 inline Symbol Softmax(const std::string& symbol_name,
7469  Symbol data,
7470  mx_float grad_scale = 1,
7471  mx_float ignore_label = -1,
7472  bool multi_output = false,
7473  bool use_ignore = false,
7474  bool preserve_shape = false,
7476  bool out_grad = false,
7477  mx_float smooth_alpha = 0) {
// Names indexed by int(normalization); the order must match the enumerator
// values of the normalization enum (presumably kBatch=0, kNull=1, kValid=2).
7478  static const char *SoftmaxNormalizationValues[] = {
7479  "batch",
7480  "null",
7481  "valid"
7482  };
7483  return Operator("Softmax")
7484  .SetParam("grad_scale", grad_scale)
7485  .SetParam("ignore_label", ignore_label)
7486  .SetParam("multi_output", multi_output)
7487  .SetParam("use_ignore", use_ignore)
7488  .SetParam("preserve_shape", preserve_shape)
7489  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
7490  .SetParam("out_grad", out_grad)
7491  .SetParam("smooth_alpha", smooth_alpha)
7492  .SetInput("data", data)
7493  .CreateSymbol(symbol_name);
7494 }
7495 
7576 inline Symbol BilinearSampler(const std::string& symbol_name,
7577  Symbol data,
7578  Symbol grid) {
7579  return Operator("BilinearSampler")
7580  .SetInput("data", data)
7581  .SetInput("grid", grid)
7582  .CreateSymbol(symbol_name);
7583 }
7584 
7641 inline Symbol ROIPooling(const std::string& symbol_name,
7642  Symbol data,
7643  Symbol rois,
7644  Shape pooled_size,
7645  mx_float spatial_scale) {
7646  return Operator("ROIPooling")
7647  .SetParam("pooled_size", pooled_size)
7648  .SetParam("spatial_scale", spatial_scale)
7649  .SetInput("data", data)
7650  .SetInput("rois", rois)
7651  .CreateSymbol(symbol_name);
7652 }
7653 
7709 inline Symbol SequenceLast(const std::string& symbol_name,
7710  Symbol data,
7711  Symbol sequence_length,
7712  bool use_sequence_length = false,
7713  int axis = 0) {
7714  return Operator("SequenceLast")
7715  .SetParam("use_sequence_length", use_sequence_length)
7716  .SetParam("axis", axis)
7717  .SetInput("data", data)
7718  .SetInput("sequence_length", sequence_length)
7719  .CreateSymbol(symbol_name);
7720 }
7721 
7725  kChannel = 0,
7726  kInstance = 1,
7727  kSpatial = 2
7728 };
7729 
/*!
 * \brief Builds an `L2Normalization` symbol from `data` with the given `eps`
 *        and mode.
 * NOTE(review): this listing skips source line 7795, which presumably declares
 *        the `mode` parameter and closes the parameter list with `) {`; as
 *        extracted, the signature is unterminated and `mode` is undeclared.
 *        Confirm against the generated header.
 */
7792 inline Symbol L2Normalization(const std::string& symbol_name,
7793  Symbol data,
7794  mx_float eps = 1e-10,
// Names indexed by int(mode); the order must match the mode enum's values
// (presumably kChannel=0, kInstance=1, kSpatial=2 as listed above — verify).
7796  static const char *L2NormalizationModeValues[] = {
7797  "channel",
7798  "instance",
7799  "spatial"
7800  };
7801  return Operator("L2Normalization")
7802  .SetParam("eps", eps)
7803  .SetParam("mode", L2NormalizationModeValues[int(mode)])
7804  .SetInput("data", data)
7805  .CreateSymbol(symbol_name);
7806 }
7807 
7813  kBatch = 0,
7814  kNull = 1,
7815  kValid = 2
7816 };
7817 
/*!
 * \brief Builds a `MakeLoss` symbol from `data`, forwarding `grad_scale`,
 *        `valid_thresh` and the normalization name as operator parameters.
 * NOTE(review): this listing skips source line 7856, which presumably declares
 *        the `normalization` parameter and closes the parameter list with
 *        `) {`; as extracted, the signature is unterminated. Confirm against
 *        the generated header.
 */
7852 inline Symbol MakeLoss(const std::string& symbol_name,
7853  Symbol data,
7854  mx_float grad_scale = 1,
7855  mx_float valid_thresh = 0,
// Names indexed by int(normalization); the order must match the enumerator
// values (presumably kBatch=0, kNull=1, kValid=2 as listed above — verify).
7857  static const char *MakeLossNormalizationValues[] = {
7858  "batch",
7859  "null",
7860  "valid"
7861  };
7862  return Operator("MakeLoss")
7863  .SetParam("grad_scale", grad_scale)
7864  .SetParam("valid_thresh", valid_thresh)
7865  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
7866  .SetInput("data", data)
7867  .CreateSymbol(symbol_name);
7868 }
7869 
7885 inline Symbol SVMOutput(const std::string& symbol_name,
7886  Symbol data,
7887  Symbol label,
7888  mx_float margin = 1,
7889  mx_float regularization_coefficient = 1,
7890  bool use_linear = false) {
7891  return Operator("SVMOutput")
7892  .SetParam("margin", margin)
7893  .SetParam("regularization_coefficient", regularization_coefficient)
7894  .SetParam("use_linear", use_linear)
7895  .SetInput("data", data)
7896  .SetInput("label", label)
7897  .CreateSymbol(symbol_name);
7898 }
7899 
7949 inline Symbol Correlation(const std::string& symbol_name,
7950  Symbol data1,
7951  Symbol data2,
7952  uint32_t kernel_size = 1,
7953  uint32_t max_displacement = 1,
7954  uint32_t stride1 = 1,
7955  uint32_t stride2 = 1,
7956  uint32_t pad_size = 0,
7957  bool is_multiply = true) {
7958  return Operator("Correlation")
7959  .SetParam("kernel_size", kernel_size)
7960  .SetParam("max_displacement", max_displacement)
7961  .SetParam("stride1", stride1)
7962  .SetParam("stride2", stride2)
7963  .SetParam("pad_size", pad_size)
7964  .SetParam("is_multiply", is_multiply)
7965  .SetInput("data1", data1)
7966  .SetInput("data2", data2)
7967  .CreateSymbol(symbol_name);
7968 }
7969 
8048 inline Symbol SequenceMask(const std::string& symbol_name,
8049  Symbol data,
8050  Symbol sequence_length,
8051  bool use_sequence_length = false,
8052  mx_float value = 0,
8053  int axis = 0) {
8054  return Operator("SequenceMask")
8055  .SetParam("use_sequence_length", use_sequence_length)
8056  .SetParam("value", value)
8057  .SetParam("axis", axis)
8058  .SetInput("data", data)
8059  .SetInput("sequence_length", sequence_length)
8060  .CreateSymbol(symbol_name);
8061 }
8062 
8071 inline Symbol choose_element_0index(const std::string& symbol_name,
8072  Symbol lhs,
8073  Symbol rhs) {
8074  return Operator("choose_element_0index")
8075  .SetInput("lhs", lhs)
8076  .SetInput("rhs", rhs)
8077  .CreateSymbol(symbol_name);
8078 }
8079 
8089 inline Symbol fill_element_0index(const std::string& symbol_name,
8090  Symbol lhs,
8091  Symbol mhs,
8092  Symbol rhs) {
8093  return Operator("fill_element_0index")
8094  .SetInput("lhs", lhs)
8095  .SetInput("mhs", mhs)
8096  .SetInput("rhs", rhs)
8097  .CreateSymbol(symbol_name);
8098 }
8099 
8139 inline Symbol khatri_rao(const std::vector<Symbol>& args) {
8140  return Operator("khatri_rao")
8141 (args)
8142  .CreateSymbol();
8143 }
8144 
8159 inline Symbol Custom(const std::vector<Symbol>& data,
8160  const std::string& op_type) {
8161  return Operator("Custom")
8162 (data)
8163  .CreateSymbol();
8164 }
8165 
8188  Symbol rhs) {
8189  return Operator("broadcast_power")
8190  .SetInput("lhs", lhs)
8191  .SetInput("rhs", rhs)
8192  .CreateSymbol();
8193 }
8194 
8219  Symbol rhs) {
8220  return Operator("broadcast_maximum")
8221  .SetInput("lhs", lhs)
8222  .SetInput("rhs", rhs)
8223  .CreateSymbol();
8224 }
8225 
8250  Symbol rhs) {
8251  return Operator("broadcast_minimum")
8252  .SetInput("lhs", lhs)
8253  .SetInput("rhs", rhs)
8254  .CreateSymbol();
8255 }
8256 
8287  Symbol rhs) {
8288  return Operator("broadcast_hypot")
8289  .SetInput("lhs", lhs)
8290  .SetInput("rhs", rhs)
8291  .CreateSymbol();
8292 }
8293 
8367 inline Symbol Reshape(Symbol data,
8368  Shape shape = Shape(),
8369  bool reverse = false,
8370  Shape target_shape = Shape(),
8371  bool keep_highest = false) {
8372  return Operator("Reshape")
8373  .SetParam("shape", shape)
8374  .SetParam("reverse", reverse)
8375  .SetParam("target_shape", target_shape)
8376  .SetParam("keep_highest", keep_highest)
8377  .SetInput("data", data)
8378  .CreateSymbol();
8379 }
8380 
8413 inline Symbol Flatten(Symbol data) {
8414  return Operator("Flatten")
8415  .SetInput("data", data)
8416  .CreateSymbol();
8417 }
8418 
8455  Shape axes = Shape()) {
8456  return Operator("transpose")
8457  .SetParam("axes", axes)
8458  .SetInput("data", data)
8459  .CreateSymbol();
8460 }
8461 
8477  int axis) {
8478  return Operator("expand_dims")
8479  .SetParam("axis", axis)
8480  .SetInput("data", data)
8481  .CreateSymbol();
8482 }
8483 
8538 inline Symbol slice(Symbol data,
8539  Shape begin,
8540  Shape end,
8541  Shape step = Shape()) {
8542  return Operator("slice")
8543  .SetParam("begin", begin)
8544  .SetParam("end", end)
8545  .SetParam("step", step)
8546  .SetInput("data", data)
8547  .CreateSymbol();
8548 }
8549 
8582  int axis,
8583  int begin,
8584  dmlc::optional<int> end) {
8585  return Operator("slice_axis")
8586  .SetParam("axis", axis)
8587  .SetParam("begin", begin)
8588  .SetParam("end", end)
8589  .SetInput("data", data)
8590  .CreateSymbol();
8591 }
8592 
8654  Symbol shape_like,
8655  Shape axes = Shape()) {
8656  return Operator("slice_like")
8657  .SetParam("axes", axes)
8658  .SetInput("data", data)
8659  .SetInput("shape_like", shape_like)
8660  .CreateSymbol();
8661 }
8662 
8696 inline Symbol clip(Symbol data,
8697  mx_float a_min,
8698  mx_float a_max) {
8699  return Operator("clip")
8700  .SetParam("a_min", a_min)
8701  .SetParam("a_max", a_max)
8702  .SetInput("data", data)
8703  .CreateSymbol();
8704 }
8705 
8739 inline Symbol repeat(Symbol data,
8740  int repeats,
8741  dmlc::optional<int> axis = dmlc::optional<int>()) {
8742  return Operator("repeat")
8743  .SetParam("repeats", repeats)
8744  .SetParam("axis", axis)
8745  .SetInput("data", data)
8746  .CreateSymbol();
8747 }
8748 
8793 inline Symbol tile(Symbol data,
8794  Shape reps) {
8795  return Operator("tile")
8796  .SetParam("reps", reps)
8797  .SetInput("data", data)
8798  .CreateSymbol();
8799 }
8800 
8823 inline Symbol reverse(Symbol data,
8824  Shape axis) {
8825  return Operator("reverse")
8826  .SetParam("axis", axis)
8827  .SetInput("data", data)
8828  .CreateSymbol();
8829 }
8830 
8853 inline Symbol stack(const std::vector<Symbol>& data,
8854  int num_args,
8855  int axis = 0) {
8856  return Operator("stack")
8857  .SetParam("num_args", num_args)
8858  .SetParam("axis", axis)
8859 (data)
8860  .CreateSymbol();
8861 }
8862 
8884 inline Symbol squeeze(const std::vector<Symbol>& data,
8885  dmlc::optional<Shape> axis = dmlc::optional<Shape>()) {
8886  return Operator("squeeze")
8887  .SetParam("axis", axis)
8888 (data)
8889  .CreateSymbol();
8890 }
8891 
8933  int block_size) {
8934  return Operator("depth_to_space")
8935  .SetParam("block_size", block_size)
8936  .SetInput("data", data)
8937  .CreateSymbol();
8938 }
8939 
8983  int block_size) {
8984  return Operator("space_to_depth")
8985  .SetParam("block_size", block_size)
8986  .SetInput("data", data)
8987  .CreateSymbol();
8988 }
8989 
9012 inline Symbol zeros_like(Symbol data) {
9013  return Operator("zeros_like")
9014  .SetInput("data", data)
9015  .CreateSymbol();
9016 }
9017 
9034 inline Symbol ones_like(Symbol data) {
9035  return Operator("ones_like")
9036  .SetInput("data", data)
9037  .CreateSymbol();
9038 }
9039 
9072  Symbol rhs) {
9073  return Operator("broadcast_add")
9074  .SetInput("lhs", lhs)
9075  .SetInput("rhs", rhs)
9076  .CreateSymbol();
9077 }
9078 
9111  Symbol rhs) {
9112  return Operator("broadcast_sub")
9113  .SetInput("lhs", lhs)
9114  .SetInput("rhs", rhs)
9115  .CreateSymbol();
9116 }
9117 
9144  Symbol rhs) {
9145  return Operator("broadcast_mul")
9146  .SetInput("lhs", lhs)
9147  .SetInput("rhs", rhs)
9148  .CreateSymbol();
9149 }
9150 
9177  Symbol rhs) {
9178  return Operator("broadcast_div")
9179  .SetInput("lhs", lhs)
9180  .SetInput("rhs", rhs)
9181  .CreateSymbol();
9182 }
9183 
9206  Symbol rhs) {
9207  return Operator("broadcast_mod")
9208  .SetInput("lhs", lhs)
9209  .SetInput("rhs", rhs)
9210  .CreateSymbol();
9211 }
9212 
9234 inline Symbol add_n(const std::vector<Symbol>& args) {
9235  return Operator("add_n")
9236 (args)
9237  .CreateSymbol();
9238 }
9239 
9270 inline Symbol argmax(Symbol data,
9271  dmlc::optional<int> axis = dmlc::optional<int>(),
9272  bool keepdims = false) {
9273  return Operator("argmax")
9274  .SetParam("axis", axis)
9275  .SetParam("keepdims", keepdims)
9276  .SetInput("data", data)
9277  .CreateSymbol();
9278 }
9279 
9310 inline Symbol argmin(Symbol data,
9311  dmlc::optional<int> axis = dmlc::optional<int>(),
9312  bool keepdims = false) {
9313  return Operator("argmin")
9314  .SetParam("axis", axis)
9315  .SetParam("keepdims", keepdims)
9316  .SetInput("data", data)
9317  .CreateSymbol();
9318 }
9319 
9342  return Operator("argmax_channel")
9343  .SetInput("data", data)
9344  .CreateSymbol();
9345 }
9346 
9402 inline Symbol pick(Symbol data,
9403  Symbol index,
9404  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9405  bool keepdims = false,
9406  PickMode mode = PickMode::kClip) {
9407  static const char *PickModeValues[] = {
9408  "clip",
9409  "wrap"
9410  };
9411  return Operator("pick")
9412  .SetParam("axis", axis)
9413  .SetParam("keepdims", keepdims)
9414  .SetParam("mode", PickModeValues[int(mode)])
9415  .SetInput("data", data)
9416  .SetInput("index", index)
9417  .CreateSymbol();
9418 }
9419 
9477 inline Symbol dot(Symbol lhs,
9478  Symbol rhs,
9479  bool transpose_a = false,
9480  bool transpose_b = false,
9481  DotForwardStype forward_stype = DotForwardStype::kNone) {
9482  static const char *DotForwardStypeValues[] = {
9483  "None",
9484  "csr",
9485  "default",
9486  "row_sparse"
9487  };
9488  return Operator("dot")
9489  .SetParam("transpose_a", transpose_a)
9490  .SetParam("transpose_b", transpose_b)
9491  .SetParam("forward_stype", DotForwardStypeValues[int(forward_stype)])
9492  .SetInput("lhs", lhs)
9493  .SetInput("rhs", rhs)
9494  .CreateSymbol();
9495 }
9496 
9522  Symbol rhs,
9523  bool transpose_a = false,
9524  bool transpose_b = false,
9526  static const char *Batch_dotForwardStypeValues[] = {
9527  "None",
9528  "csr",
9529  "default",
9530  "row_sparse"
9531  };
9532  return Operator("batch_dot")
9533  .SetParam("transpose_a", transpose_a)
9534  .SetParam("transpose_b", transpose_b)
9535  .SetParam("forward_stype", Batch_dotForwardStypeValues[int(forward_stype)])
9536  .SetInput("lhs", lhs)
9537  .SetInput("rhs", rhs)
9538  .CreateSymbol();
9539 }
9540 
9559 inline Symbol relu(Symbol data) {
9560  return Operator("relu")
9561  .SetInput("data", data)
9562  .CreateSymbol();
9563 }
9564 
9579 inline Symbol sigmoid(Symbol data) {
9580  return Operator("sigmoid")
9581  .SetInput("data", data)
9582  .CreateSymbol();
9583 }
9584 
9600  mx_float alpha = 0.2,
9601  mx_float beta = 0.5) {
9602  return Operator("hard_sigmoid")
9603  .SetParam("alpha", alpha)
9604  .SetParam("beta", beta)
9605  .SetInput("data", data)
9606  .CreateSymbol();
9607 }
9608 
9623 inline Symbol softsign(Symbol data) {
9624  return Operator("softsign")
9625  .SetInput("data", data)
9626  .CreateSymbol();
9627 }
9628 
9661 inline Symbol BlockGrad(Symbol data) {
9662  return Operator("BlockGrad")
9663  .SetInput("data", data)
9664  .CreateSymbol();
9665 }
9666 
9695 inline Symbol make_loss(Symbol data) {
9696  return Operator("make_loss")
9697  .SetInput("data", data)
9698  .CreateSymbol();
9699 }
9700 
9735  Symbol rhs) {
9736  return Operator("reshape_like")
9737  .SetInput("lhs", lhs)
9738  .SetInput("rhs", rhs)
9739  .CreateSymbol();
9740 }
9741 
9760  dmlc::optional<int> lhs_begin = dmlc::optional<int>(),
9761  dmlc::optional<int> lhs_end = dmlc::optional<int>(),
9762  dmlc::optional<int> rhs_begin = dmlc::optional<int>(),
9763  dmlc::optional<int> rhs_end = dmlc::optional<int>()) {
9764  return Operator("shape_array")
9765  .SetParam("lhs_begin", lhs_begin)
9766  .SetParam("lhs_end", lhs_end)
9767  .SetParam("rhs_begin", rhs_begin)
9768  .SetParam("rhs_end", rhs_end)
9769  .SetInput("data", data)
9770  .CreateSymbol();
9771 }
9772 
9786 inline Symbol size_array(Symbol data) {
9787  return Operator("size_array")
9788  .SetInput("data", data)
9789  .CreateSymbol();
9790 }
9791 
9810 inline Symbol Cast(Symbol data,
9811  CastDtype dtype) {
9812  static const char *CastDtypeValues[] = {
9813  "float16",
9814  "float32",
9815  "float64",
9816  "int32",
9817  "int64",
9818  "int8",
9819  "uint8"
9820  };
9821  return Operator("Cast")
9822  .SetParam("dtype", CastDtypeValues[int(dtype)])
9823  .SetInput("data", data)
9824  .CreateSymbol();
9825 }
9826 
9840 inline Symbol negative(Symbol data) {
9841  return Operator("negative")
9842  .SetInput("data", data)
9843  .CreateSymbol();
9844 }
9845 
9861 inline Symbol reciprocal(Symbol data) {
9862  return Operator("reciprocal")
9863  .SetInput("data", data)
9864  .CreateSymbol();
9865 }
9866 
9886 inline Symbol abs(Symbol data) {
9887  return Operator("abs")
9888  .SetInput("data", data)
9889  .CreateSymbol();
9890 }
9891 
9911 inline Symbol sign(Symbol data) {
9912  return Operator("sign")
9913  .SetInput("data", data)
9914  .CreateSymbol();
9915 }
9916 
9936 inline Symbol round(Symbol data) {
9937  return Operator("round")
9938  .SetInput("data", data)
9939  .CreateSymbol();
9940 }
9941 
9965 inline Symbol rint(Symbol data) {
9966  return Operator("rint")
9967  .SetInput("data", data)
9968  .CreateSymbol();
9969 }
9970 
9992 inline Symbol ceil(Symbol data) {
9993  return Operator("ceil")
9994  .SetInput("data", data)
9995  .CreateSymbol();
9996 }
9997 
10019 inline Symbol floor(Symbol data) {
10020  return Operator("floor")
10021  .SetInput("data", data)
10022  .CreateSymbol();
10023 }
10024 
10047 inline Symbol trunc(Symbol data) {
10048  return Operator("trunc")
10049  .SetInput("data", data)
10050  .CreateSymbol();
10051 }
10052 
10073 inline Symbol fix(Symbol data) {
10074  return Operator("fix")
10075  .SetInput("data", data)
10076  .CreateSymbol();
10077 }
10078 
10101 inline Symbol square(Symbol data) {
10102  return Operator("square")
10103  .SetInput("data", data)
10104  .CreateSymbol();
10105 }
10106 
10129 inline Symbol sqrt(Symbol data) {
10130  return Operator("sqrt")
10131  .SetInput("data", data)
10132  .CreateSymbol();
10133 }
10134 
10153 inline Symbol rsqrt(Symbol data) {
10154  return Operator("rsqrt")
10155  .SetInput("data", data)
10156  .CreateSymbol();
10157 }
10158 
10181 inline Symbol cbrt(Symbol data) {
10182  return Operator("cbrt")
10183  .SetInput("data", data)
10184  .CreateSymbol();
10185 }
10186 
10203 inline Symbol rcbrt(Symbol data) {
10204  return Operator("rcbrt")
10205  .SetInput("data", data)
10206  .CreateSymbol();
10207 }
10208 
10227 inline Symbol exp(Symbol data) {
10228  return Operator("exp")
10229  .SetInput("data", data)
10230  .CreateSymbol();
10231 }
10232 
10246 inline Symbol log(Symbol data) {
10247  return Operator("log")
10248  .SetInput("data", data)
10249  .CreateSymbol();
10250 }
10251 
10265 inline Symbol log10(Symbol data) {
10266  return Operator("log10")
10267  .SetInput("data", data)
10268  .CreateSymbol();
10269 }
10270 
10284 inline Symbol log2(Symbol data) {
10285  return Operator("log2")
10286  .SetInput("data", data)
10287  .CreateSymbol();
10288 }
10289 
10308 inline Symbol log1p(Symbol data) {
10309  return Operator("log1p")
10310  .SetInput("data", data)
10311  .CreateSymbol();
10312 }
10313 
10331 inline Symbol expm1(Symbol data) {
10332  return Operator("expm1")
10333  .SetInput("data", data)
10334  .CreateSymbol();
10335 }
10336 
10347 inline Symbol gamma(Symbol data) {
10348  return Operator("gamma")
10349  .SetInput("data", data)
10350  .CreateSymbol();
10351 }
10352 
10363 inline Symbol gammaln(Symbol data) {
10364  return Operator("gammaln")
10365  .SetInput("data", data)
10366  .CreateSymbol();
10367 }
10368 
10379 inline Symbol logical_not(Symbol data) {
10380  return Operator("logical_not")
10381  .SetInput("data", data)
10382  .CreateSymbol();
10383 }
10384 
10442 inline Symbol sum(Symbol data,
10443  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10444  bool keepdims = false,
10445  bool exclude = false) {
10446  return Operator("sum")
10447  .SetParam("axis", axis)
10448  .SetParam("keepdims", keepdims)
10449  .SetParam("exclude", exclude)
10450  .SetInput("data", data)
10451  .CreateSymbol();
10452 }
10453 
10477 inline Symbol mean(Symbol data,
10478  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10479  bool keepdims = false,
10480  bool exclude = false) {
10481  return Operator("mean")
10482  .SetParam("axis", axis)
10483  .SetParam("keepdims", keepdims)
10484  .SetParam("exclude", exclude)
10485  .SetInput("data", data)
10486  .CreateSymbol();
10487 }
10488 
10512 inline Symbol prod(Symbol data,
10513  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10514  bool keepdims = false,
10515  bool exclude = false) {
10516  return Operator("prod")
10517  .SetParam("axis", axis)
10518  .SetParam("keepdims", keepdims)
10519  .SetParam("exclude", exclude)
10520  .SetInput("data", data)
10521  .CreateSymbol();
10522 }
10523 
10549 inline Symbol nansum(Symbol data,
10550  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10551  bool keepdims = false,
10552  bool exclude = false) {
10553  return Operator("nansum")
10554  .SetParam("axis", axis)
10555  .SetParam("keepdims", keepdims)
10556  .SetParam("exclude", exclude)
10557  .SetInput("data", data)
10558  .CreateSymbol();
10559 }
10560 
10586 inline Symbol nanprod(Symbol data,
10587  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10588  bool keepdims = false,
10589  bool exclude = false) {
10590  return Operator("nanprod")
10591  .SetParam("axis", axis)
10592  .SetParam("keepdims", keepdims)
10593  .SetParam("exclude", exclude)
10594  .SetInput("data", data)
10595  .CreateSymbol();
10596 }
10597 
10621 inline Symbol max(Symbol data,
10622  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10623  bool keepdims = false,
10624  bool exclude = false) {
10625  return Operator("max")
10626  .SetParam("axis", axis)
10627  .SetParam("keepdims", keepdims)
10628  .SetParam("exclude", exclude)
10629  .SetInput("data", data)
10630  .CreateSymbol();
10631 }
10632 
10656 inline Symbol min(Symbol data,
10657  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10658  bool keepdims = false,
10659  bool exclude = false) {
10660  return Operator("min")
10661  .SetParam("axis", axis)
10662  .SetParam("keepdims", keepdims)
10663  .SetParam("exclude", exclude)
10664  .SetInput("data", data)
10665  .CreateSymbol();
10666 }
10667 
10697  Shape axis = Shape(),
10698  Shape size = Shape()) {
10699  return Operator("broadcast_axis")
10700  .SetParam("axis", axis)
10701  .SetParam("size", size)
10702  .SetInput("data", data)
10703  .CreateSymbol();
10704 }
10705 
10734  Shape shape = Shape()) {
10735  return Operator("broadcast_to")
10736  .SetParam("shape", shape)
10737  .SetInput("data", data)
10738  .CreateSymbol();
10739 }
10740 
10765  Symbol rhs) {
10766  return Operator("broadcast_like")
10767  .SetInput("lhs", lhs)
10768  .SetInput("rhs", rhs)
10769  .CreateSymbol();
10770 }
10771 
10814 inline Symbol norm(Symbol data,
10815  int ord = 2,
10816  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
10817  bool keepdims = false) {
10818  return Operator("norm")
10819  .SetParam("ord", ord)
10820  .SetParam("axis", axis)
10821  .SetParam("keepdims", keepdims)
10822  .SetInput("data", data)
10823  .CreateSymbol();
10824 }
10825 
10868 inline Symbol topk(Symbol data,
10869  dmlc::optional<int> axis = dmlc::optional<int>(-1),
10870  int k = 1,
10871  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
10872  bool is_ascend = false,
10873  TopkDtype dtype = TopkDtype::kFloat32) {
10874  static const char *TopkRetTypValues[] = {
10875  "both",
10876  "indices",
10877  "mask",
10878  "value"
10879  };
10880  static const char *TopkDtypeValues[] = {
10881  "float16",
10882  "float32",
10883  "float64",
10884  "int32",
10885  "uint8"
10886  };
10887  return Operator("topk")
10888  .SetParam("axis", axis)
10889  .SetParam("k", k)
10890  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
10891  .SetParam("is_ascend", is_ascend)
10892  .SetParam("dtype", TopkDtypeValues[int(dtype)])
10893  .SetInput("data", data)
10894  .CreateSymbol();
10895 }
10896 
10928 inline Symbol sort(Symbol data,
10929  dmlc::optional<int> axis = dmlc::optional<int>(-1),
10930  bool is_ascend = true) {
10931  return Operator("sort")
10932  .SetParam("axis", axis)
10933  .SetParam("is_ascend", is_ascend)
10934  .SetInput("data", data)
10935  .CreateSymbol();
10936 }
10937 
/*!
 * \brief Builds an `argsort` symbol over `data`, forwarding `axis`,
 *        `is_ascend` and the dtype name as operator parameters.
 * NOTE(review): this listing skips source line 10972, which presumably declares
 *        the `dtype` parameter and closes the parameter list with `) {`; as
 *        extracted, the signature is unterminated. Confirm against the
 *        generated header.
 */
10969 inline Symbol argsort(Symbol data,
10970  dmlc::optional<int> axis = dmlc::optional<int>(-1),
10971  bool is_ascend = true,
// Names indexed by int(dtype); the order must match the dtype enum's values.
10973  static const char *ArgsortDtypeValues[] = {
10974  "float16",
10975  "float32",
10976  "float64",
10977  "int32",
10978  "uint8"
10979  };
10980  return Operator("argsort")
10981  .SetParam("axis", axis)
10982  .SetParam("is_ascend", is_ascend)
10983  .SetParam("dtype", ArgsortDtypeValues[int(dtype)])
10984  .SetInput("data", data)
10985  .CreateSymbol();
10986 }
10987 
11007  Symbol rhs) {
11008  return Operator("elemwise_add")
11009  .SetInput("lhs", lhs)
11010  .SetInput("rhs", rhs)
11011  .CreateSymbol();
11012 }
11013 
11033  Symbol rhs) {
11034  return Operator("elemwise_sub")
11035  .SetInput("lhs", lhs)
11036  .SetInput("rhs", rhs)
11037  .CreateSymbol();
11038 }
11039 
11058  Symbol rhs) {
11059  return Operator("elemwise_mul")
11060  .SetInput("lhs", lhs)
11061  .SetInput("rhs", rhs)
11062  .CreateSymbol();
11063 }
11064 
11076  Symbol rhs) {
11077  return Operator("elemwise_div")
11078  .SetInput("lhs", lhs)
11079  .SetInput("rhs", rhs)
11080  .CreateSymbol();
11081 }
11082 
11146  Symbol weight,
11147  int input_dim,
11148  int output_dim,
11150  bool sparse_grad = false) {
11151  static const char *EmbeddingDtypeValues[] = {
11152  "float16",
11153  "float32",
11154  "float64",
11155  "int32",
11156  "int64",
11157  "int8",
11158  "uint8"
11159  };
11160  return Operator("Embedding")
11161  .SetParam("input_dim", input_dim)
11162  .SetParam("output_dim", output_dim)
11163  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
11164  .SetParam("sparse_grad", sparse_grad)
11165  .SetInput("data", data)
11166  .SetInput("weight", weight)
11167  .CreateSymbol();
11168 }
11169 
11222 inline Symbol take(Symbol a,
11223  Symbol indices,
11224  int axis = 0,
11225  TakeMode mode = TakeMode::kClip) {
11226  static const char *TakeModeValues[] = {
11227  "clip",
11228  "raise",
11229  "wrap"
11230  };
11231  return Operator("take")
11232  .SetParam("axis", axis)
11233  .SetParam("mode", TakeModeValues[int(mode)])
11234  .SetInput("a", a)
11235  .SetInput("indices", indices)
11236  .CreateSymbol();
11237 }
11238 
11267  Symbol indices) {
11268  return Operator("batch_take")
11269  .SetInput("a", a)
11270  .SetInput("indices", indices)
11271  .CreateSymbol();
11272 }
11273 
/*!
 * \brief Builds a `one_hot` symbol from `indices`, forwarding `depth`,
 *        `on_value`, `off_value` and the dtype name as operator parameters.
 * NOTE(review): this listing skips source line 11321, which presumably declares
 *        the `dtype` parameter and closes the parameter list with `) {`; as
 *        extracted, the signature is unterminated. Confirm against the
 *        generated header.
 */
11317 inline Symbol one_hot(Symbol indices,
11318  int depth,
11319  double on_value = 1,
11320  double off_value = 0,
// Names indexed by int(dtype); the order must match the dtype enum's values.
11322  static const char *One_hotDtypeValues[] = {
11323  "float16",
11324  "float32",
11325  "float64",
11326  "int32",
11327  "int64",
11328  "int8",
11329  "uint8"
11330  };
11331  return Operator("one_hot")
11332  .SetParam("depth", depth)
11333  .SetParam("on_value", on_value)
11334  .SetParam("off_value", off_value)
11335  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
11336  .SetInput("indices", indices)
11337  .CreateSymbol();
11338 }
11339 
11371  Symbol indices) {
11372  return Operator("gather_nd")
11373  .SetInput("data", data)
11374  .SetInput("indices", indices)
11375  .CreateSymbol();
11376 }
11377 
11429  Symbol indices,
11430  Shape shape) {
11431  return Operator("scatter_nd")
11432  .SetParam("shape", shape)
11433  .SetInput("data", data)
11434  .SetInput("indices", indices)
11435  .CreateSymbol();
11436 }
11437 
11460  Symbol rhs) {
11461  return Operator("broadcast_equal")
11462  .SetInput("lhs", lhs)
11463  .SetInput("rhs", rhs)
11464  .CreateSymbol();
11465 }
11466 
11489  Symbol rhs) {
11490  return Operator("broadcast_not_equal")
11491  .SetInput("lhs", lhs)
11492  .SetInput("rhs", rhs)
11493  .CreateSymbol();
11494 }
11495 
11518  Symbol rhs) {
11519  return Operator("broadcast_greater")
11520  .SetInput("lhs", lhs)
11521  .SetInput("rhs", rhs)
11522  .CreateSymbol();
11523 }
11524 
11547  Symbol rhs) {
11548  return Operator("broadcast_greater_equal")
11549  .SetInput("lhs", lhs)
11550  .SetInput("rhs", rhs)
11551  .CreateSymbol();
11552 }
11553 
11576  Symbol rhs) {
11577  return Operator("broadcast_lesser")
11578  .SetInput("lhs", lhs)
11579  .SetInput("rhs", rhs)
11580  .CreateSymbol();
11581 }
11582 
11605  Symbol rhs) {
11606  return Operator("broadcast_lesser_equal")
11607  .SetInput("lhs", lhs)
11608  .SetInput("rhs", rhs)
11609  .CreateSymbol();
11610 }
11611 
11634  Symbol rhs) {
11635  return Operator("broadcast_logical_and")
11636  .SetInput("lhs", lhs)
11637  .SetInput("rhs", rhs)
11638  .CreateSymbol();
11639 }
11640 
11663  Symbol rhs) {
11664  return Operator("broadcast_logical_or")
11665  .SetInput("lhs", lhs)
11666  .SetInput("rhs", rhs)
11667  .CreateSymbol();
11668 }
11669 
11692  Symbol rhs) {
11693  return Operator("broadcast_logical_xor")
11694  .SetInput("lhs", lhs)
11695  .SetInput("rhs", rhs)
11696  .CreateSymbol();
11697 }
11698 
11741 inline Symbol diag(Symbol data,
11742  dmlc::optional<int> k = dmlc::optional<int>(0)) {
11743  return Operator("diag")
11744  .SetParam("k", k)
11745  .SetInput("data", data)
11746  .CreateSymbol();
11747 }
11748 
11783 inline Symbol where(Symbol condition,
11784  Symbol x,
11785  Symbol y) {
11786  return Operator("where")
11787  .SetInput("condition", condition)
11788  .SetInput("x", x)
11789  .SetInput("y", y)
11790  .CreateSymbol();
11791 }
11792 
11818  mx_float scalar) {
11819  return Operator("smooth_l1")
11820  .SetParam("scalar", scalar)
11821  .SetInput("data", data)
11822  .CreateSymbol();
11823 }
11824 
11870  Cast_storageStype stype) {
11871  static const char *Cast_storageStypeValues[] = {
11872  "csr",
11873  "default",
11874  "row_sparse"
11875  };
11876  return Operator("cast_storage")
11877  .SetParam("stype", Cast_storageStypeValues[int(stype)])
11878  .SetInput("data", data)
11879  .CreateSymbol();
11880 }
11881 
11902 inline Symbol sin(Symbol data) {
11903  return Operator("sin")
11904  .SetInput("data", data)
11905  .CreateSymbol();
11906 }
11907 
11924 inline Symbol cos(Symbol data) {
11925  return Operator("cos")
11926  .SetInput("data", data)
11927  .CreateSymbol();
11928 }
11929 
11950 inline Symbol tan(Symbol data) {
11951  return Operator("tan")
11952  .SetInput("data", data)
11953  .CreateSymbol();
11954 }
11955 
11977 inline Symbol arcsin(Symbol data) {
11978  return Operator("arcsin")
11979  .SetInput("data", data)
11980  .CreateSymbol();
11981 }
11982 
12000 inline Symbol arccos(Symbol data) {
12001  return Operator("arccos")
12002  .SetInput("data", data)
12003  .CreateSymbol();
12004 }
12005 
12026 inline Symbol arctan(Symbol data) {
12027  return Operator("arctan")
12028  .SetInput("data", data)
12029  .CreateSymbol();
12030 }
12031 
12050 inline Symbol degrees(Symbol data) {
12051  return Operator("degrees")
12052  .SetInput("data", data)
12053  .CreateSymbol();
12054 }
12055 
12074 inline Symbol radians(Symbol data) {
12075  return Operator("radians")
12076  .SetInput("data", data)
12077  .CreateSymbol();
12078 }
12079 
12098 inline Symbol sinh(Symbol data) {
12099  return Operator("sinh")
12100  .SetInput("data", data)
12101  .CreateSymbol();
12102 }
12103 
12118 inline Symbol cosh(Symbol data) {
12119  return Operator("cosh")
12120  .SetInput("data", data)
12121  .CreateSymbol();
12122 }
12123 
12142 inline Symbol tanh(Symbol data) {
12143  return Operator("tanh")
12144  .SetInput("data", data)
12145  .CreateSymbol();
12146 }
12147 
12164 inline Symbol arcsinh(Symbol data) {
12165  return Operator("arcsinh")
12166  .SetInput("data", data)
12167  .CreateSymbol();
12168 }
12169 
12182 inline Symbol arccosh(Symbol data) {
12183  return Operator("arccosh")
12184  .SetInput("data", data)
12185  .CreateSymbol();
12186 }
12187 
12204 inline Symbol arctanh(Symbol data) {
12205  return Operator("arctanh")
12206  .SetInput("data", data)
12207  .CreateSymbol();
12208 }
12209 
12277 inline Symbol Pooling(Symbol data,
12278  Shape kernel = Shape(),
12280  bool global_pool = false,
12281  bool cudnn_off = false,
12283  Shape stride = Shape(),
12284  Shape pad = Shape(),
12285  dmlc::optional<int> p_value = dmlc::optional<int>(),
12286  dmlc::optional<bool> count_include_pad = dmlc::optional<bool>()) {
12287  static const char *PoolingPoolTypeValues[] = {
12288  "avg",
12289  "lp",
12290  "max",
12291  "sum"
12292  };
12293  static const char *PoolingPoolingConventionValues[] = {
12294  "full",
12295  "valid"
12296  };
12297  return Operator("Pooling")
12298  .SetParam("kernel", kernel)
12299  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
12300  .SetParam("global_pool", global_pool)
12301  .SetParam("cudnn_off", cudnn_off)
12302  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
12303  .SetParam("stride", stride)
12304  .SetParam("pad", pad)
12305  .SetParam("p_value", p_value)
12306  .SetParam("count_include_pad", count_include_pad)
12307  .SetInput("data", data)
12308  .CreateSymbol();
12309 }
12310 
12342 inline Symbol softmax(Symbol data,
12343  int axis = -1,
12344  dmlc::optional<double> temperature = dmlc::optional<double>()) {
12345  return Operator("softmax")
12346  .SetParam("axis", axis)
12347  .SetParam("temperature", temperature)
12348  .SetInput("data", data)
12349  .CreateSymbol();
12350 }
12351 
12375  int axis = -1,
12376  dmlc::optional<double> temperature = dmlc::optional<double>()) {
12377  return Operator("log_softmax")
12378  .SetParam("axis", axis)
12379  .SetParam("temperature", temperature)
12380  .SetInput("data", data)
12381  .CreateSymbol();
12382 }
12383 
12413  Symbol weight,
12414  Symbol bias,
12415  Shape kernel,
12416  uint32_t num_filter,
12417  Shape stride = Shape(),
12418  Shape dilate = Shape(),
12419  Shape pad = Shape(),
12420  Shape adj = Shape(),
12421  Shape target_shape = Shape(),
12422  uint32_t num_group = 1,
12423  uint64_t workspace = 512,
12424  bool no_bias = true,
12426  bool cudnn_off = false,
12428  static const char *DeconvolutionCudnnTuneValues[] = {
12429  "None",
12430  "fastest",
12431  "limited_workspace",
12432  "off"
12433  };
12434  static const char *DeconvolutionLayoutValues[] = {
12435  "None",
12436  "NCDHW",
12437  "NCHW",
12438  "NCW",
12439  "NDHWC",
12440  "NHWC"
12441  };
12442  return Operator("Deconvolution")
12443  .SetParam("kernel", kernel)
12444  .SetParam("num_filter", num_filter)
12445  .SetParam("stride", stride)
12446  .SetParam("dilate", dilate)
12447  .SetParam("pad", pad)
12448  .SetParam("adj", adj)
12449  .SetParam("target_shape", target_shape)
12450  .SetParam("num_group", num_group)
12451  .SetParam("workspace", workspace)
12452  .SetParam("no_bias", no_bias)
12453  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
12454  .SetParam("cudnn_off", cudnn_off)
12455  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
12456  .SetInput("data", data)
12457  .SetInput("weight", weight)
12458  .SetInput("bias", bias)
12459  .CreateSymbol();
12460 }
12461 
12481  ActivationActType act_type) {
12482  static const char *ActivationActTypeValues[] = {
12483  "relu",
12484  "sigmoid",
12485  "softrelu",
12486  "softsign",
12487  "tanh"
12488  };
12489  return Operator("Activation")
12490  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
12491  .SetInput("data", data)
12492  .CreateSymbol();
12493 }
12494 
12564  Symbol gamma,
12565  Symbol beta,
12566  Symbol moving_mean,
12567  Symbol moving_var,
12568  double eps = 0.001,
12569  mx_float momentum = 0.9,
12570  bool fix_gamma = true,
12571  bool use_global_stats = false,
12572  bool output_mean_var = false,
12573  int axis = 1,
12574  bool cudnn_off = false) {
12575  return Operator("BatchNorm")
12576  .SetParam("eps", eps)
12577  .SetParam("momentum", momentum)
12578  .SetParam("fix_gamma", fix_gamma)
12579  .SetParam("use_global_stats", use_global_stats)
12580  .SetParam("output_mean_var", output_mean_var)
12581  .SetParam("axis", axis)
12582  .SetParam("cudnn_off", cudnn_off)
12583  .SetInput("data", data)
12584  .SetInput("gamma", gamma)
12585  .SetInput("beta", beta)
12586  .SetInput("moving_mean", moving_mean)
12587  .SetInput("moving_var", moving_var)
12588  .CreateSymbol();
12589 }
12590 
12688  Symbol weight,
12689  Symbol bias,
12690  Shape kernel,
12691  uint32_t num_filter,
12692  Shape stride = Shape(),
12693  Shape dilate = Shape(),
12694  Shape pad = Shape(),
12695  uint32_t num_group = 1,
12696  uint64_t workspace = 1024,
12697  bool no_bias = false,
12699  bool cudnn_off = false,
12701  static const char *ConvolutionCudnnTuneValues[] = {
12702  "None",
12703  "fastest",
12704  "limited_workspace",
12705  "off"
12706  };
12707  static const char *ConvolutionLayoutValues[] = {
12708  "None",
12709  "NCDHW",
12710  "NCHW",
12711  "NCW",
12712  "NDHWC",
12713  "NHWC"
12714  };
12715  return Operator("Convolution")
12716  .SetParam("kernel", kernel)
12717  .SetParam("num_filter", num_filter)
12718  .SetParam("stride", stride)
12719  .SetParam("dilate", dilate)
12720  .SetParam("pad", pad)
12721  .SetParam("num_group", num_group)
12722  .SetParam("workspace", workspace)
12723  .SetParam("no_bias", no_bias)
12724  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
12725  .SetParam("cudnn_off", cudnn_off)
12726  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
12727  .SetInput("data", data)
12728  .SetInput("weight", weight)
12729  .SetInput("bias", bias)
12730  .CreateSymbol();
12731 }
12732 
12747 inline Symbol UpSampling(const std::vector<Symbol>& data,
12748  uint32_t scale,
12749  UpSamplingSampleType sample_type,
12750  int num_args,
12751  uint32_t num_filter = 0,
12753  uint64_t workspace = 512) {
12754  static const char *UpSamplingSampleTypeValues[] = {
12755  "bilinear",
12756  "nearest"
12757  };
12758  static const char *UpSamplingMultiInputModeValues[] = {
12759  "concat",
12760  "sum"
12761  };
12762  return Operator("UpSampling")
12763  .SetParam("scale", scale)
12764  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
12765  .SetParam("num_args", num_args)
12766  .SetParam("num_filter", num_filter)
12767  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
12768  .SetParam("workspace", workspace)
12769 (data)
12770  .CreateSymbol();
12771 }
12772 
12818 inline Symbol Concat(const std::vector<Symbol>& data,
12819  int num_args,
12820  int dim = 1) {
12821  return Operator("Concat")
12822  .SetParam("num_args", num_args)
12823  .SetParam("dim", dim)
12824 (data)
12825  .CreateSymbol();
12826 }
12827 
12866  Symbol gamma,
12867  Symbol beta,
12868  int axis = -1,
12869  mx_float eps = 1e-05,
12870  bool output_mean_var = false) {
12871  return Operator("LayerNorm")
12872  .SetParam("axis", axis)
12873  .SetParam("eps", eps)
12874  .SetParam("output_mean_var", output_mean_var)
12875  .SetInput("data", data)
12876  .SetInput("gamma", gamma)
12877  .SetInput("beta", beta)
12878  .CreateSymbol();
12879 }
12880 
12907 inline Symbol LRN(Symbol data,
12908  uint32_t nsize,
12909  mx_float alpha = 0.0001,
12910  mx_float beta = 0.75,
12911  mx_float knorm = 2) {
12912  return Operator("LRN")
12913  .SetParam("nsize", nsize)
12914  .SetParam("alpha", alpha)
12915  .SetParam("beta", beta)
12916  .SetParam("knorm", knorm)
12917  .SetInput("data", data)
12918  .CreateSymbol();
12919 }
12920 
12960 inline Symbol Dropout(Symbol data,
12961  mx_float p = 0.5,
12963  Shape axes = Shape()) {
12964  static const char *DropoutModeValues[] = {
12965  "always",
12966  "training"
12967  };
12968  return Operator("Dropout")
12969  .SetParam("p", p)
12970  .SetParam("mode", DropoutModeValues[int(mode)])
12971  .SetParam("axes", axes)
12972  .SetInput("data", data)
12973  .CreateSymbol();
12974 }
12975 
13010  static const char *SoftmaxActivationModeValues[] = {
13011  "channel",
13012  "instance"
13013  };
13014  return Operator("SoftmaxActivation")
13015  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
13016  .SetInput("data", data)
13017  .CreateSymbol();
13018 }
13019 
13063  Symbol weight,
13064  Symbol bias,
13065  int num_hidden,
13066  bool no_bias = false,
13067  bool flatten = true) {
13068  return Operator("FullyConnected")
13069  .SetParam("num_hidden", num_hidden)
13070  .SetParam("no_bias", no_bias)
13071  .SetParam("flatten", flatten)
13072  .SetInput("data", data)
13073  .SetInput("weight", weight)
13074  .SetInput("bias", bias)
13075  .CreateSymbol();
13076 }
13077 
13173 inline Symbol Pad(Symbol data,
13174  PadMode mode,
13175  Shape pad_width,
13176  double constant_value = 0) {
13177  static const char *PadModeValues[] = {
13178  "constant",
13179  "edge",
13180  "reflect"
13181  };
13182  return Operator("Pad")
13183  .SetParam("mode", PadModeValues[int(mode)])
13184  .SetParam("pad_width", pad_width)
13185  .SetParam("constant_value", constant_value)
13186  .SetInput("data", data)
13187  .CreateSymbol();
13188 }
13189 
13220  Symbol gamma,
13222  mx_float slope = 0.25,
13223  mx_float lower_bound = 0.125,
13224  mx_float upper_bound = 0.334) {
13225  static const char *LeakyReLUActTypeValues[] = {
13226  "elu",
13227  "leaky",
13228  "prelu",
13229  "rrelu",
13230  "selu"
13231  };
13232  return Operator("LeakyReLU")
13233  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
13234  .SetParam("slope", slope)
13235  .SetParam("lower_bound", lower_bound)
13236  .SetParam("upper_bound", upper_bound)
13237  .SetInput("data", data)
13238  .SetInput("gamma", gamma)
13239  .CreateSymbol();
13240 }
13241 
13269 inline Symbol SwapAxis(Symbol data,
13270  uint32_t dim1 = 0,
13271  uint32_t dim2 = 0) {
13272  return Operator("SwapAxis")
13273  .SetParam("dim1", dim1)
13274  .SetParam("dim2", dim2)
13275  .SetInput("data", data)
13276  .CreateSymbol();
13277 }
13278 
13339  Symbol gamma,
13340  Symbol beta,
13341  mx_float eps = 0.001,
13342  mx_float momentum = 0.9,
13343  bool fix_gamma = true,
13344  bool use_global_stats = false,
13345  bool output_mean_var = false) {
13346  return Operator("BatchNorm_v1")
13347  .SetParam("eps", eps)
13348  .SetParam("momentum", momentum)
13349  .SetParam("fix_gamma", fix_gamma)
13350  .SetParam("use_global_stats", use_global_stats)
13351  .SetParam("output_mean_var", output_mean_var)
13352  .SetInput("data", data)
13353  .SetInput("gamma", gamma)
13354  .SetInput("beta", beta)
13355  .CreateSymbol();
13356 }
13357 
13395  Symbol label) {
13396  return Operator("softmax_cross_entropy")
13397  .SetInput("data", data)
13398  .SetInput("label", label)
13399  .CreateSymbol();
13400 }
13401 
13431  Symbol label,
13432  mx_float grad_scale = 1) {
13433  return Operator("LinearRegressionOutput")
13434  .SetParam("grad_scale", grad_scale)
13435  .SetInput("data", data)
13436  .SetInput("label", label)
13437  .CreateSymbol();
13438 }
13439 
13470  Symbol label,
13471  mx_float grad_scale = 1) {
13472  return Operator("MAERegressionOutput")
13473  .SetParam("grad_scale", grad_scale)
13474  .SetInput("data", data)
13475  .SetInput("label", label)
13476  .CreateSymbol();
13477 }
13478 
13515  Symbol label,
13516  mx_float grad_scale = 1) {
13517  return Operator("LogisticRegressionOutput")
13518  .SetParam("grad_scale", grad_scale)
13519  .SetInput("data", data)
13520  .SetInput("label", label)
13521  .CreateSymbol();
13522 }
13523 
13533  mx_float sparseness_target = 0.1,
13534  mx_float penalty = 0.001,
13535  mx_float momentum = 0.9) {
13536  return Operator("IdentityAttachKLSparseReg")
13537  .SetParam("sparseness_target", sparseness_target)
13538  .SetParam("penalty", penalty)
13539  .SetParam("momentum", momentum)
13540  .SetInput("data", data)
13541  .CreateSymbol();
13542 }
13543 
13572  Symbol grad,
13573  mx_float lr,
13574  mx_float wd = 0,
13575  mx_float rescale_grad = 1,
13576  mx_float clip_gradient = -1) {
13577  return Operator("signsgd_update")
13578  .SetParam("lr", lr)
13579  .SetParam("wd", wd)
13580  .SetParam("rescale_grad", rescale_grad)
13581  .SetParam("clip_gradient", clip_gradient)
13582  .SetInput("weight", weight)
13583  .SetInput("grad", grad)
13584  .CreateSymbol();
13585 }
13586 
13621  Symbol grad,
13622  Symbol mom,
13623  mx_float lr,
13624  mx_float momentum = 0,
13625  mx_float wd = 0,
13626  mx_float rescale_grad = 1,
13627  mx_float clip_gradient = -1,
13628  mx_float wd_lh = 0) {
13629  return Operator("signum_update")
13630  .SetParam("lr", lr)
13631  .SetParam("momentum", momentum)
13632  .SetParam("wd", wd)
13633  .SetParam("rescale_grad", rescale_grad)
13634  .SetParam("clip_gradient", clip_gradient)
13635  .SetParam("wd_lh", wd_lh)
13636  .SetInput("weight", weight)
13637  .SetInput("grad", grad)
13638  .SetInput("mom", mom)
13639  .CreateSymbol();
13640 }
13641 
13669 inline Symbol sgd_update(Symbol weight,
13670  Symbol grad,
13671  mx_float lr,
13672  mx_float wd = 0,
13673  mx_float rescale_grad = 1,
13674  mx_float clip_gradient = -1,
13675  bool lazy_update = true) {
13676  return Operator("sgd_update")
13677  .SetParam("lr", lr)
13678  .SetParam("wd", wd)
13679  .SetParam("rescale_grad", rescale_grad)
13680  .SetParam("clip_gradient", clip_gradient)
13681  .SetParam("lazy_update", lazy_update)
13682  .SetInput("weight", weight)
13683  .SetInput("grad", grad)
13684  .CreateSymbol();
13685 }
13686 
13731  Symbol grad,
13732  Symbol mom,
13733  mx_float lr,
13734  mx_float momentum = 0,
13735  mx_float wd = 0,
13736  mx_float rescale_grad = 1,
13737  mx_float clip_gradient = -1,
13738  bool lazy_update = true) {
13739  return Operator("sgd_mom_update")
13740  .SetParam("lr", lr)
13741  .SetParam("momentum", momentum)
13742  .SetParam("wd", wd)
13743  .SetParam("rescale_grad", rescale_grad)
13744  .SetParam("clip_gradient", clip_gradient)
13745  .SetParam("lazy_update", lazy_update)
13746  .SetInput("weight", weight)
13747  .SetInput("grad", grad)
13748  .SetInput("mom", mom)
13749  .CreateSymbol();
13750 }
13751 
13767  Symbol grad,
13768  Symbol weight32,
13769  mx_float lr,
13770  mx_float wd = 0,
13771  mx_float rescale_grad = 1,
13772  mx_float clip_gradient = -1,
13773  bool lazy_update = true) {
13774  return Operator("mp_sgd_update")
13775  .SetParam("lr", lr)
13776  .SetParam("wd", wd)
13777  .SetParam("rescale_grad", rescale_grad)
13778  .SetParam("clip_gradient", clip_gradient)
13779  .SetParam("lazy_update", lazy_update)
13780  .SetInput("weight", weight)
13781  .SetInput("grad", grad)
13782  .SetInput("weight32", weight32)
13783  .CreateSymbol();
13784 }
13785 
13803  Symbol grad,
13804  Symbol mom,
13805  Symbol weight32,
13806  mx_float lr,
13807  mx_float momentum = 0,
13808  mx_float wd = 0,
13809  mx_float rescale_grad = 1,
13810  mx_float clip_gradient = -1,
13811  bool lazy_update = true) {
13812  return Operator("mp_sgd_mom_update")
13813  .SetParam("lr", lr)
13814  .SetParam("momentum", momentum)
13815  .SetParam("wd", wd)
13816  .SetParam("rescale_grad", rescale_grad)
13817  .SetParam("clip_gradient", clip_gradient)
13818  .SetParam("lazy_update", lazy_update)
13819  .SetInput("weight", weight)
13820  .SetInput("grad", grad)
13821  .SetInput("mom", mom)
13822  .SetInput("weight32", weight32)
13823  .CreateSymbol();
13824 }
13825 
13860 inline Symbol ftml_update(Symbol weight,
13861  Symbol grad,
13862  Symbol d,
13863  Symbol v,
13864  Symbol z,
13865  mx_float lr,
13866  int t,
13867  mx_float beta1 = 0.6,
13868  mx_float beta2 = 0.999,
13869  double epsilon = 1e-08,
13870  mx_float wd = 0,
13871  mx_float rescale_grad = 1,
13872  mx_float clip_grad = -1) {
13873  return Operator("ftml_update")
13874  .SetParam("lr", lr)
13875  .SetParam("t", t)
13876  .SetParam("beta1", beta1)
13877  .SetParam("beta2", beta2)
13878  .SetParam("epsilon", epsilon)
13879  .SetParam("wd", wd)
13880  .SetParam("rescale_grad", rescale_grad)
13881  .SetParam("clip_grad", clip_grad)
13882  .SetInput("weight", weight)
13883  .SetInput("grad", grad)
13884  .SetInput("d", d)
13885  .SetInput("v", v)
13886  .SetInput("z", z)
13887  .CreateSymbol();
13888 }
13889 
13938 inline Symbol adam_update(Symbol weight,
13939  Symbol grad,
13940  Symbol mean,
13941  Symbol var,
13942  mx_float lr,
13943  mx_float beta1 = 0.9,
13944  mx_float beta2 = 0.999,
13945  mx_float epsilon = 1e-08,
13946  mx_float wd = 0,
13947  mx_float rescale_grad = 1,
13948  mx_float clip_gradient = -1,
13949  bool lazy_update = true) {
13950  return Operator("adam_update")
13951  .SetParam("lr", lr)
13952  .SetParam("beta1", beta1)
13953  .SetParam("beta2", beta2)
13954  .SetParam("epsilon", epsilon)
13955  .SetParam("wd", wd)
13956  .SetParam("rescale_grad", rescale_grad)
13957  .SetParam("clip_gradient", clip_gradient)
13958  .SetParam("lazy_update", lazy_update)
13959  .SetInput("weight", weight)
13960  .SetInput("grad", grad)
13961  .SetInput("mean", mean)
13962  .SetInput("var", var)
13963  .CreateSymbol();
13964 }
13965 
14019  Symbol grad,
14020  Symbol n,
14021  mx_float lr,
14022  mx_float gamma1 = 0.95,
14023  mx_float epsilon = 1e-08,
14024  mx_float wd = 0,
14025  mx_float rescale_grad = 1,
14026  mx_float clip_gradient = -1,
14027  mx_float clip_weights = -1) {
14028  return Operator("rmsprop_update")
14029  .SetParam("lr", lr)
14030  .SetParam("gamma1", gamma1)
14031  .SetParam("epsilon", epsilon)
14032  .SetParam("wd", wd)
14033  .SetParam("rescale_grad", rescale_grad)
14034  .SetParam("clip_gradient", clip_gradient)
14035  .SetParam("clip_weights", clip_weights)
14036  .SetInput("weight", weight)
14037  .SetInput("grad", grad)
14038  .SetInput("n", n)
14039  .CreateSymbol();
14040 }
14041 
14087  Symbol grad,
14088  Symbol n,
14089  Symbol g,
14090  Symbol delta,
14091  mx_float lr,
14092  mx_float gamma1 = 0.95,
14093  mx_float gamma2 = 0.9,
14094  mx_float epsilon = 1e-08,
14095  mx_float wd = 0,
14096  mx_float rescale_grad = 1,
14097  mx_float clip_gradient = -1,
14098  mx_float clip_weights = -1) {
14099  return Operator("rmspropalex_update")
14100  .SetParam("lr", lr)
14101  .SetParam("gamma1", gamma1)
14102  .SetParam("gamma2", gamma2)
14103  .SetParam("epsilon", epsilon)
14104  .SetParam("wd", wd)
14105  .SetParam("rescale_grad", rescale_grad)
14106  .SetParam("clip_gradient", clip_gradient)
14107  .SetParam("clip_weights", clip_weights)
14108  .SetInput("weight", weight)
14109  .SetInput("grad", grad)
14110  .SetInput("n", n)
14111  .SetInput("g", g)
14112  .SetInput("delta", delta)
14113  .CreateSymbol();
14114 }
14115 
14154 inline Symbol ftrl_update(Symbol weight,
14155  Symbol grad,
14156  Symbol z,
14157  Symbol n,
14158  mx_float lr,
14159  mx_float lamda1 = 0.01,
14160  mx_float beta = 1,
14161  mx_float wd = 0,
14162  mx_float rescale_grad = 1,
14163  mx_float clip_gradient = -1) {
14164  return Operator("ftrl_update")
14165  .SetParam("lr", lr)
14166  .SetParam("lamda1", lamda1)
14167  .SetParam("beta", beta)
14168  .SetParam("wd", wd)
14169  .SetParam("rescale_grad", rescale_grad)
14170  .SetParam("clip_gradient", clip_gradient)
14171  .SetInput("weight", weight)
14172  .SetInput("grad", grad)
14173  .SetInput("z", z)
14174  .SetInput("n", n)
14175  .CreateSymbol();
14176 }
14177 
14249  int num_outputs,
14250  int axis = 1,
14251  bool squeeze_axis = false) {
14252  return Operator("SliceChannel")
14253  .SetParam("num_outputs", num_outputs)
14254  .SetParam("axis", axis)
14255  .SetParam("squeeze_axis", squeeze_axis)
14256  .SetInput("data", data)
14257  .CreateSymbol();
14258 }
14259 
14310  Symbol gamma,
14311  Symbol beta,
14312  mx_float eps = 0.001) {
14313  return Operator("InstanceNorm")
14314  .SetParam("eps", eps)
14315  .SetInput("data", data)
14316  .SetInput("gamma", gamma)
14317  .SetInput("beta", beta)
14318  .CreateSymbol();
14319 }
14320 
14331  GridGeneratorTransformType transform_type,
14332  Shape target_shape = Shape(0,0)) {
14333  static const char *GridGeneratorTransformTypeValues[] = {
14334  "affine",
14335  "warp"
14336  };
14337  return Operator("GridGenerator")
14338  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
14339  .SetParam("target_shape", target_shape)
14340  .SetInput("data", data)
14341  .CreateSymbol();
14342 }
14343 
14395  Shape kernel = Shape(),
14397  bool global_pool = false,
14399  Shape stride = Shape(),
14400  Shape pad = Shape()) {
14401  static const char *Pooling_v1PoolTypeValues[] = {
14402  "avg",
14403  "max",
14404  "sum"
14405  };
14406  static const char *Pooling_v1PoolingConventionValues[] = {
14407  "full",
14408  "valid"
14409  };
14410  return Operator("Pooling_v1")
14411  .SetParam("kernel", kernel)
14412  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
14413  .SetParam("global_pool", global_pool)
14414  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
14415  .SetParam("stride", stride)
14416  .SetParam("pad", pad)
14417  .SetInput("data", data)
14418  .CreateSymbol();
14419 }
14420 
14486 inline Symbol RNN(Symbol data,
14487  Symbol parameters,
14488  Symbol state,
14489  Symbol state_cell,
14490  uint32_t state_size,
14491  uint32_t num_layers,
14492  RNNMode mode,
14493  bool bidirectional = false,
14494  mx_float p = 0,
14495  bool state_outputs = false,
14496  dmlc::optional<int> projection_size = dmlc::optional<int>(),
14497  dmlc::optional<double> lstm_state_clip_min = dmlc::optional<double>(),
14498  dmlc::optional<double> lstm_state_clip_max = dmlc::optional<double>(),
14499  bool lstm_state_clip_nan = false) {
14500  static const char *RNNModeValues[] = {
14501  "gru",
14502  "lstm",
14503  "rnn_relu",
14504  "rnn_tanh"
14505  };
14506  return Operator("RNN")
14507  .SetParam("state_size", state_size)
14508  .SetParam("num_layers", num_layers)
14509  .SetParam("mode", RNNModeValues[int(mode)])
14510  .SetParam("bidirectional", bidirectional)
14511  .SetParam("p", p)
14512  .SetParam("state_outputs", state_outputs)
14513  .SetParam("projection_size", projection_size)
14514  .SetParam("lstm_state_clip_min", lstm_state_clip_min)
14515  .SetParam("lstm_state_clip_max", lstm_state_clip_max)
14516  .SetParam("lstm_state_clip_nan", lstm_state_clip_nan)
14517  .SetInput("data", data)
14518  .SetInput("parameters", parameters)
14519  .SetInput("state", state)
14520  .SetInput("state_cell", state_cell)
14521  .CreateSymbol();
14522 }
14523 
14554  Symbol weight,
14555  Symbol bias,
14556  Shape kernel,
14557  uint32_t num_filter,
14558  Shape stride = Shape(),
14559  Shape dilate = Shape(),
14560  Shape pad = Shape(),
14561  uint32_t num_group = 1,
14562  uint64_t workspace = 1024,
14563  bool no_bias = false,
14565  bool cudnn_off = false,
14567  static const char *Convolution_v1CudnnTuneValues[] = {
14568  "None",
14569  "fastest",
14570  "limited_workspace",
14571  "off"
14572  };
14573  static const char *Convolution_v1LayoutValues[] = {
14574  "None",
14575  "NCDHW",
14576  "NCHW",
14577  "NDHWC",
14578  "NHWC"
14579  };
14580  return Operator("Convolution_v1")
14581  .SetParam("kernel", kernel)
14582  .SetParam("num_filter", num_filter)
14583  .SetParam("stride", stride)
14584  .SetParam("dilate", dilate)
14585  .SetParam("pad", pad)
14586  .SetParam("num_group", num_group)
14587  .SetParam("workspace", workspace)
14588  .SetParam("no_bias", no_bias)
14589  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
14590  .SetParam("cudnn_off", cudnn_off)
14591  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
14592  .SetInput("data", data)
14593  .SetInput("weight", weight)
14594  .SetInput("bias", bias)
14595  .CreateSymbol();
14596 }
14597 
14617 inline Symbol Crop(const std::vector<Symbol>& data,
14618  int num_args,
14619  Shape offset = Shape(0,0),
14620  Shape h_w = Shape(0,0),
14621  bool center_crop = false) {
14622  return Operator("Crop")
14623  .SetParam("num_args", num_args)
14624  .SetParam("offset", offset)
14625  .SetParam("h_w", h_w)
14626  .SetParam("center_crop", center_crop)
14627 (data)
14628  .CreateSymbol();
14629 }
14630 
14707  Symbol sequence_length,
14708  bool use_sequence_length = false,
14709  int axis = 0) {
14710  return Operator("SequenceReverse")
14711  .SetParam("use_sequence_length", use_sequence_length)
14712  .SetParam("axis", axis)
14713  .SetInput("data", data)
14714  .SetInput("sequence_length", sequence_length)
14715  .CreateSymbol();
14716 }
14717 
14728  Symbol loc,
14729  SpatialTransformerTransformType transform_type,
14730  SpatialTransformerSamplerType sampler_type,
14731  Shape target_shape = Shape(0,0)) {
14732  static const char *SpatialTransformerTransformTypeValues[] = {
14733  "affine"
14734  };
14735  static const char *SpatialTransformerSamplerTypeValues[] = {
14736  "bilinear"
14737  };
14738  return Operator("SpatialTransformer")
14739  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
14740  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
14741  .SetParam("target_shape", target_shape)
14742  .SetInput("data", data)
14743  .SetInput("loc", loc)
14744  .CreateSymbol();
14745 }
14746 
14842  Symbol label,
14843  mx_float grad_scale = 1,
14844  mx_float ignore_label = -1,
14845  bool multi_output = false,
14846  bool use_ignore = false,
14847  bool preserve_shape = false,
14849  bool out_grad = false,
14850  mx_float smooth_alpha = 0) {
14851  static const char *SoftmaxOutputNormalizationValues[] = {
14852  "batch",
14853  "null",
14854  "valid"
14855  };
14856  return Operator("SoftmaxOutput")
14857  .SetParam("grad_scale", grad_scale)
14858  .SetParam("ignore_label", ignore_label)
14859  .SetParam("multi_output", multi_output)
14860  .SetParam("use_ignore", use_ignore)
14861  .SetParam("preserve_shape", preserve_shape)
14862  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
14863  .SetParam("out_grad", out_grad)
14864  .SetParam("smooth_alpha", smooth_alpha)
14865  .SetInput("data", data)
14866  .SetInput("label", label)
14867  .CreateSymbol();
14868 }
14869 
14896 inline Symbol Softmax(Symbol data,
14897  mx_float grad_scale = 1,
14898  mx_float ignore_label = -1,
14899  bool multi_output = false,
14900  bool use_ignore = false,
14901  bool preserve_shape = false,
14903  bool out_grad = false,
14904  mx_float smooth_alpha = 0) {
14905  static const char *SoftmaxNormalizationValues[] = {
14906  "batch",
14907  "null",
14908  "valid"
14909  };
14910  return Operator("Softmax")
14911  .SetParam("grad_scale", grad_scale)
14912  .SetParam("ignore_label", ignore_label)
14913  .SetParam("multi_output", multi_output)
14914  .SetParam("use_ignore", use_ignore)
14915  .SetParam("preserve_shape", preserve_shape)
14916  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
14917  .SetParam("out_grad", out_grad)
14918  .SetParam("smooth_alpha", smooth_alpha)
14919  .SetInput("data", data)
14920  .CreateSymbol();
14921 }
14922 
15003  Symbol grid) {
15004  return Operator("BilinearSampler")
15005  .SetInput("data", data)
15006  .SetInput("grid", grid)
15007  .CreateSymbol();
15008 }
15009 
15066  Symbol rois,
15067  Shape pooled_size,
15068  mx_float spatial_scale) {
15069  return Operator("ROIPooling")
15070  .SetParam("pooled_size", pooled_size)
15071  .SetParam("spatial_scale", spatial_scale)
15072  .SetInput("data", data)
15073  .SetInput("rois", rois)
15074  .CreateSymbol();
15075 }
15076 
15132  Symbol sequence_length,
15133  bool use_sequence_length = false,
15134  int axis = 0) {
15135  return Operator("SequenceLast")
15136  .SetParam("use_sequence_length", use_sequence_length)
15137  .SetParam("axis", axis)
15138  .SetInput("data", data)
15139  .SetInput("sequence_length", sequence_length)
15140  .CreateSymbol();
15141 }
15142 
15205  mx_float eps = 1e-10,
15207  static const char *L2NormalizationModeValues[] = {
15208  "channel",
15209  "instance",
15210  "spatial"
15211  };
15212  return Operator("L2Normalization")
15213  .SetParam("eps", eps)
15214  .SetParam("mode", L2NormalizationModeValues[int(mode)])
15215  .SetInput("data", data)
15216  .CreateSymbol();
15217 }
15218 
15252 inline Symbol MakeLoss(Symbol data,
15253  mx_float grad_scale = 1,
15254  mx_float valid_thresh = 0,
15256  static const char *MakeLossNormalizationValues[] = {
15257  "batch",
15258  "null",
15259  "valid"
15260  };
15261  return Operator("MakeLoss")
15262  .SetParam("grad_scale", grad_scale)
15263  .SetParam("valid_thresh", valid_thresh)
15264  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
15265  .SetInput("data", data)
15266  .CreateSymbol();
15267 }
15268 
15284  Symbol label,
15285  mx_float margin = 1,
15286  mx_float regularization_coefficient = 1,
15287  bool use_linear = false) {
15288  return Operator("SVMOutput")
15289  .SetParam("margin", margin)
15290  .SetParam("regularization_coefficient", regularization_coefficient)
15291  .SetParam("use_linear", use_linear)
15292  .SetInput("data", data)
15293  .SetInput("label", label)
15294  .CreateSymbol();
15295 }
15296 
15346  Symbol data2,
15347  uint32_t kernel_size = 1,
15348  uint32_t max_displacement = 1,
15349  uint32_t stride1 = 1,
15350  uint32_t stride2 = 1,
15351  uint32_t pad_size = 0,
15352  bool is_multiply = true) {
15353  return Operator("Correlation")
15354  .SetParam("kernel_size", kernel_size)
15355  .SetParam("max_displacement", max_displacement)
15356  .SetParam("stride1", stride1)
15357  .SetParam("stride2", stride2)
15358  .SetParam("pad_size", pad_size)
15359  .SetParam("is_multiply", is_multiply)
15360  .SetInput("data1", data1)
15361  .SetInput("data2", data2)
15362  .CreateSymbol();
15363 }
15364 
15443  Symbol sequence_length,
15444  bool use_sequence_length = false,
15445  mx_float value = 0,
15446  int axis = 0) {
15447  return Operator("SequenceMask")
15448  .SetParam("use_sequence_length", use_sequence_length)
15449  .SetParam("value", value)
15450  .SetParam("axis", axis)
15451  .SetInput("data", data)
15452  .SetInput("sequence_length", sequence_length)
15453  .CreateSymbol();
15454 }
15455 
15464  Symbol rhs) {
15465  return Operator("choose_element_0index")
15466  .SetInput("lhs", lhs)
15467  .SetInput("rhs", rhs)
15468  .CreateSymbol();
15469 }
15470 
15480  Symbol mhs,
15481  Symbol rhs) {
15482  return Operator("fill_element_0index")
15483  .SetInput("lhs", lhs)
15484  .SetInput("mhs", mhs)
15485  .SetInput("rhs", rhs)
15486  .CreateSymbol();
15487 }
15488 
15489 } //namespace cpp
15490 } //namespace mxnet
15491 #endif // MXNET_CPP_OP_H_
Symbol Convolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, ConvolutionCudnnTune cudnn_tune=ConvolutionCudnnTune::kNone, bool cudnn_off=false, ConvolutionLayout layout=ConvolutionLayout::kNone)
Definition: op.h:5047
Symbol Pooling(const std::string &symbol_name, Symbol data, Shape kernel=Shape(), PoolingPoolType pool_type=PoolingPoolType::kMax, bool global_pool=false, bool cudnn_off=false, PoolingPoolingConvention pooling_convention=PoolingPoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape(), dmlc::optional< int > p_value=dmlc::optional< int >(), dmlc::optional< bool > count_include_pad=dmlc::optional< bool >())
Definition: op.h:4574
Symbol mp_sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, Symbol weight32, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6255
Symbol fix(const std::string &symbol_name, Symbol data)
Definition: op.h:2147
Symbol Crop(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, Shape offset=Shape(0, 0), Shape h_w=Shape(0, 0), bool center_crop=false)
Definition: op.h:7153
Symbol broadcast_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1116
Symbol arcsin(const std::string &symbol_name, Symbol data)
Definition: op.h:4236
Symbol FullyConnected(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, int num_hidden, bool no_bias=false, bool flatten=true)
Definition: op.h:5466
Symbol arccosh(const std::string &symbol_name, Symbol data)
Definition: op.h:4459
Symbol arctan(const std::string &symbol_name, Symbol data)
Definition: op.h:4289
Symbol SwapAxis(const std::string &symbol_name, Symbol data, uint32_t dim1=0, uint32_t dim2=0)
Definition: op.h:5698
Symbol cast_storage(const std::string &symbol_name, Symbol data, Cast_storageStype stype)
Definition: op.h:4120
Symbol add_n(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:1213
Symbol log1p(const std::string &symbol_name, Symbol data)
Definition: op.h:2402
SoftmaxActivationMode
Definition: op.h:5372
Symbol SpatialTransformer(const std::string &symbol_name, Symbol data, Symbol loc, SpatialTransformerTransformType transform_type, SpatialTransformerSamplerType sampler_type, Shape target_shape=Shape(0, 0))
Definition: op.h:7279
Symbol argsort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true, ArgsortDtype dtype=ArgsortDtype::kFloat32)
Definition: op.h:3132
Symbol slice(const std::string &symbol_name, Symbol data, Shape begin, Shape end, Shape step=Shape())
Definition: op.h:481
Symbol exp(const std::string &symbol_name, Symbol data)
Definition: op.h:2313
Symbol transpose(const std::string &symbol_name, Symbol data, Shape axes=Shape())
Definition: op.h:393
Symbol RNN(const std::string &symbol_name, Symbol data, Symbol parameters, Symbol state, Symbol state_cell, uint32_t state_size, uint32_t num_layers, RNNMode mode, bool bidirectional=false, mx_float p=0, bool state_outputs=false, dmlc::optional< int > projection_size=dmlc::optional< int >(), dmlc::optional< double > lstm_state_clip_min=dmlc::optional< double >(), dmlc::optional< double > lstm_state_clip_max=dmlc::optional< double >(), bool lstm_state_clip_nan=false)
Definition: op.h:6991
Symbol clip(const std::string &symbol_name, Symbol data, mx_float a_min, mx_float a_max)
Definition: op.h:645
Symbol elemwise_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3246
Symbol Embedding(const std::string &symbol_name, Symbol data, Symbol weight, int input_dim, int output_dim, EmbeddingDtype dtype=EmbeddingDtype::kFloat32, bool sparse_grad=false)
Definition: op.h:3330
Symbol ROIPooling(const std::string &symbol_name, Symbol data, Symbol rois, Shape pooled_size, mx_float spatial_scale)
Definition: op.h:7641
Symbol broadcast_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1151
Convolution_v1Layout
Definition: op.h:7049
Symbol argmin(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1293
Symbol dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false, DotForwardStype forward_stype=DotForwardStype::kNone)
Definition: op.h:1486
Symbol topk(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), int k=1, TopkRetTyp ret_typ=TopkRetTyp::kIndices, bool is_ascend=false, TopkDtype dtype=TopkDtype::kFloat32)
Definition: op.h:3016
Symbol SequenceReverse(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:7244
Symbol broadcast_lesser(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3802
Symbol fill_element_0index(const std::string &symbol_name, Symbol lhs, Symbol mhs, Symbol rhs)
Definition: op.h:8089
Symbol Convolution_v1(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, Convolution_v1CudnnTune cudnn_tune=Convolution_v1CudnnTune::kNone, bool cudnn_off=false, Convolution_v1Layout layout=Convolution_v1Layout::kNone)
Definition: op.h:7087
Symbol broadcast_not_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3709
TakeMode
Definition: op.h:3360
Symbol SequenceLast(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:7709
Symbol ftrl_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol z, Symbol n, mx_float lr, mx_float lamda1=0.01, mx_float beta=1, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:6617
Symbol reciprocal(const std::string &symbol_name, Symbol data)
Definition: op.h:1919
TopkRetTyp
Definition: op.h:2956
namespace of mxnet
Definition: base.h:118
Symbol reshape_like(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1770
Pooling_v1PoolingConvention
Definition: op.h:6832
Symbol broadcast_lesser_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3833
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
Symbol InstanceNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001)
Definition: op.h:6776
Symbol sign(const std::string &symbol_name, Symbol data)
Definition: op.h:1973
GridGeneratorTransformType
Definition: op.h:6792
Cast_storageStype
Definition: op.h:4069
Symbol log_softmax(const std::string &symbol_name, Symbol data, int axis=-1, dmlc::optional< double > temperature=dmlc::optional< double >())
Definition: op.h:4675
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
Symbol ones_like(const std::string &symbol_name, Symbol data)
Definition: op.h:1001
RNNMode
Definition: op.h:6918
PadMode
Definition: op.h:5486
Symbol smooth_l1(const std::string &symbol_name, Symbol data, mx_float scalar)
Definition: op.h:4058
Symbol where(const std::string &symbol_name, Symbol condition, Symbol x, Symbol y)
Definition: op.h:4022
Symbol space_to_depth(const std::string &symbol_name, Symbol data, int block_size)
Definition: op.h:945
Symbol expm1(const std::string &symbol_name, Symbol data)
Definition: op.h:2427
Symbol elemwise_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3171
PoolingPoolType
Definition: op.h:4492
Symbol relu(const std::string &symbol_name, Symbol data)
Definition: op.h:1583
Symbol reverse(const std::string &symbol_name, Symbol data, Shape axis)
Definition: op.h:778
Symbol rsqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2233
Symbol mp_sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol weight32, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6217
Symbol batch_dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false, Batch_dotForwardStype forward_stype=Batch_dotForwardStype::kNone)
Definition: op.h:1543
SpatialTransformerTransformType
Definition: op.h:7259
ActivationActType
Definition: op.h:4788
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2207
Symbol Softmax(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxNormalization normalization=SoftmaxNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:7468
Symbol rint(const std::string &symbol_name, Symbol data)
Definition: op.h:2031
Symbol IdentityAttachKLSparseReg(const std::string &symbol_name, Symbol data, mx_float sparseness_target=0.1, mx_float penalty=0.001, mx_float momentum=0.9)
Definition: op.h:5973
Symbol sinh(const std::string &symbol_name, Symbol data)
Definition: op.h:4367
Symbol scatter_nd(const std::string &symbol_name, Symbol data, Symbol indices, Shape shape)
Definition: op.h:3645
Symbol broadcast_greater_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3771
Symbol LRN(const std::string &symbol_name, Symbol data, uint32_t nsize, mx_float alpha=0.0001, mx_float beta=0.75, mx_float knorm=2)
Definition: op.h:5290
Symbol LayerNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, int axis=-1, mx_float eps=1e-05, bool output_mean_var=false)
Definition: op.h:5246
Symbol arcsinh(const std::string &symbol_name, Symbol data)
Definition: op.h:4439
Symbol MAERegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:5906
Symbol SliceChannel(const std::string &symbol_name, Symbol data, int num_outputs, int axis=1, bool squeeze_axis=false)
Definition: op.h:6713
PoolingPoolingConvention
Definition: op.h:4501
Symbol broadcast_minimum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:180
ArgsortDtype
Definition: op.h:3092
Symbol broadcast_maximum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:147
Symbol Cast(const std::string &symbol_name, Symbol data, CastDtype dtype)
Definition: op.h:1864
DeconvolutionLayout
Definition: op.h:4697
Symbol trunc(const std::string &symbol_name, Symbol data)
Definition: op.h:2119
Pooling_v1PoolType
Definition: op.h:6824
Symbol round(const std::string &symbol_name, Symbol data)
Definition: op.h:2000
Symbol adam_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mean, Symbol var, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6395
Symbol Dropout(const std::string &symbol_name, Symbol data, mx_float p=0.5, DropoutMode mode=DropoutMode::kTraining, Shape axes=Shape())
Definition: op.h:5352
Symbol squeeze(const std::string &symbol_name, const std::vector< Symbol > &data, dmlc::optional< Shape > axis=dmlc::optional< Shape >())
Definition: op.h:843
TopkDtype
Definition: op.h:2965
Symbol broadcast_logical_or(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3895
Symbol khatri_rao(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:62
Symbol cos(const std::string &symbol_name, Symbol data)
Definition: op.h:4179
Symbol L2Normalization(const std::string &symbol_name, Symbol data, mx_float eps=1e-10, L2NormalizationMode mode=L2NormalizationMode::kInstance)
Definition: op.h:7792
Symbol max(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2735
Symbol Correlation(const std::string &symbol_name, Symbol data1, Symbol data2, uint32_t kernel_size=1, uint32_t max_displacement=1, uint32_t stride1=1, uint32_t stride2=1, uint32_t pad_size=0, bool is_multiply=true)
Definition: op.h:7949
Symbol zeros_like(const std::string &symbol_name, Symbol data)
Definition: op.h:977
EmbeddingDtype
Definition: op.h:3257
Symbol broadcast_mod(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1182
Symbol cbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2263
operator helper functions
Symbol broadcast_logical_and(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3864
Symbol logical_not(const std::string &symbol_name, Symbol data)
Definition: op.h:2481
Symbol tanh(const std::string &symbol_name, Symbol data)
Definition: op.h:4415
Symbol broadcast_to(const std::string &symbol_name, Symbol data, Shape shape=Shape())
Definition: op.h:2853
Symbol elemwise_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3199
Symbol diag(const std::string &symbol_name, Symbol data, dmlc::optional< int > k=dmlc::optional< int >(0))
Definition: op.h:3978
DropoutMode
Definition: op.h:5307
Symbol norm(const std::string &symbol_name, Symbol data, int ord=2, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false)
Definition: op.h:2938
Symbol MakeLoss(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float valid_thresh=0, MakeLossNormalization normalization=MakeLossNormalization::kNull)
Definition: op.h:7852
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:2334
Symbol broadcast_logical_xor(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3926
Symbol sigmoid(const std::string &symbol_name, Symbol data)
Definition: op.h:1605
CastDtype
Definition: op.h:1835
DotForwardStype
Definition: op.h:1421
ConvolutionLayout
Definition: op.h:4941
Symbol LogisticRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:5953
Symbol gamma(const std::string &symbol_name, Symbol data)
Definition: op.h:2445
Symbol sin(const std::string &symbol_name, Symbol data)
Definition: op.h:4155
UpSamplingMultiInputMode
Definition: op.h:5104
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
Symbol elemwise_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3226
Symbol softmax(const std::string &symbol_name, Symbol data, int axis=-1, dmlc::optional< double > temperature=dmlc::optional< double >())
Definition: op.h:4641
Batch_dotForwardStype
Definition: op.h:1511
SpatialTransformerSamplerType
Definition: op.h:7265
Symbol Pad(const std::string &symbol_name, Symbol data, PadMode mode, Shape pad_width, double constant_value=0)
Definition: op.h:5588
Symbol square(const std::string &symbol_name, Symbol data)
Definition: op.h:2177
Symbol LeakyReLU(const std::string &symbol_name, Symbol data, Symbol gamma, LeakyReLUActType act_type=LeakyReLUActType::kLeaky, mx_float slope=0.25, mx_float lower_bound=0.125, mx_float upper_bound=0.334)
Definition: op.h:5646
One_hotDtype
Definition: op.h:3476
Symbol nansum(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2659
UpSamplingSampleType
Definition: op.h:5096
Symbol rmsprop_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, mx_float lr, mx_float gamma1=0.95, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:6477
Symbol make_loss(const std::string &symbol_name, Symbol data)
Definition: op.h:1729
Symbol SoftmaxActivation(const std::string &symbol_name, Symbol data, SoftmaxActivationMode mode=SoftmaxActivationMode::kInstance)
Definition: op.h:5410
Symbol broadcast_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3678
Symbol nanprod(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2698
Symbol Deconvolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), Shape adj=Shape(), Shape target_shape=Shape(), uint32_t num_group=1, uint64_t workspace=512, bool no_bias=true, DeconvolutionCudnnTune cudnn_tune=DeconvolutionCudnnTune::kNone, bool cudnn_off=false, DeconvolutionLayout layout=DeconvolutionLayout::kNone)
Definition: op.h:4735
Symbol broadcast_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1040
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:58
Symbol tan(const std::string &symbol_name, Symbol data)
Definition: op.h:4207
Convolution_v1CudnnTune
Definition: op.h:7039
Symbol repeat(const std::string &symbol_name, Symbol data, int repeats, dmlc::optional< int > axis=dmlc::optional< int >())
Definition: op.h:690
Symbol slice_axis(const std::string &symbol_name, Symbol data, int axis, int begin, dmlc::optional< int > end)
Definition: op.h:526
Symbol expand_dims(const std::string &symbol_name, Symbol data, int axis)
Definition: op.h:417
Symbol arctanh(const std::string &symbol_name, Symbol data)
Definition: op.h:4483
Symbol softmax_cross_entropy(const std::string &symbol_name, Symbol data, Symbol label)
Definition: op.h:5827
Symbol broadcast_axis(const std::string &symbol_name, Symbol data, Shape axis=Shape(), Shape size=Shape())
Definition: op.h:2814
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1946
Symbol cosh(const std::string &symbol_name, Symbol data)
Definition: op.h:4389
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:3078
Symbol gather_nd(const std::string &symbol_name, Symbol data, Symbol indices)
Definition: op.h:3585
Symbol slice_like(const std::string &symbol_name, Symbol data, Symbol shape_like, Shape axes=Shape())
Definition: op.h:600
Symbol BilinearSampler(const std::string &symbol_name, Symbol data, Symbol grid)
Definition: op.h:7576
Symbol Custom(const std::string &symbol_name, const std::vector< Symbol > &data, const std::string &op_type)
Definition: op.h:84
Symbol Pooling_v1(const std::string &symbol_name, Symbol data, Shape kernel=Shape(), Pooling_v1PoolType pool_type=Pooling_v1PoolType::kMax, bool global_pool=false, Pooling_v1PoolingConvention pooling_convention=Pooling_v1PoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:6888
Symbol broadcast_hypot(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:219
Symbol BatchNorm_v1(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false)
Definition: op.h:5769
Symbol UpSampling(const std::string &symbol_name, const std::vector< Symbol > &data, uint32_t scale, UpSamplingSampleType sample_type, int num_args, uint32_t num_filter=0, UpSamplingMultiInputMode multi_input_mode=UpSamplingMultiInputMode::kConcat, uint64_t workspace=512)
Definition: op.h:5124
Symbol Activation(const std::string &symbol_name, Symbol data, ActivationActType act_type)
Definition: op.h:4815
float mx_float
manually define float
Definition: c_api.h:60
Symbol SVMOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float margin=1, mx_float regularization_coefficient=1, bool use_linear=false)
Definition: op.h:7885
Symbol radians(const std::string &symbol_name, Symbol data)
Definition: op.h:4341
Symbol Concat(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int dim=1)
Definition: op.h:5197
L2NormalizationMode
Definition: op.h:7724
Symbol SequenceMask(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, mx_float value=0, int axis=0)
Definition: op.h:8048
Symbol stack(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int axis=0)
Definition: op.h:810
Symbol floor(const std::string &symbol_name, Symbol data)
Definition: op.h:2089
Symbol broadcast_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1081
Symbol take(const std::string &symbol_name, Symbol a, Symbol indices, int axis=0, TakeMode mode=TakeMode::kClip)
Definition: op.h:3419
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:2060
Symbol gammaln(const std::string &symbol_name, Symbol data)
Definition: op.h:2463
Symbol tile(const std::string &symbol_name, Symbol data, Shape reps)
Definition: op.h:746
Symbol min(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2772
Symbol signum_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float wd_lh=0)
Definition: op.h:6065
Symbol depth_to_space(const std::string &symbol_name, Symbol data, int block_size)
Definition: op.h:893
Symbol rmspropalex_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, Symbol g, Symbol delta, mx_float lr, mx_float gamma1=0.95, mx_float gamma2=0.9, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:6547
Symbol ftml_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol d, Symbol v, Symbol z, mx_float lr, int t, mx_float beta1=0.6, mx_float beta2=0.999, double epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_grad=-1)
Definition: op.h:6315
SoftmaxNormalization
Definition: op.h:7435
DeconvolutionCudnnTune
Definition: op.h:4688
ConvolutionCudnnTune
Definition: op.h:4931
Symbol prod(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2620
Symbol pick(const std::string &symbol_name, Symbol data, Symbol index, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool keepdims=false, PickMode mode=PickMode::kClip)
Definition: op.h:1398
definition of shape
Symbol broadcast_greater(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3740
Symbol BatchNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, Symbol moving_mean, Symbol moving_var, double eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false, int axis=1, bool cudnn_off=false)
Definition: op.h:4900
Symbol rcbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:2287
PickMode
Definition: op.h:1337
Symbol signsgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:6014
Symbol broadcast_power(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:114
SoftmaxOutputNormalization
Definition: op.h:7302
Symbol SoftmaxOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxOutputNormalization normalization=SoftmaxOutputNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:7403
Symbol Flatten(const std::string &symbol_name, Symbol data)
Definition: op.h:350
Symbol BlockGrad(const std::string &symbol_name, Symbol data)
Definition: op.h:1693
LeakyReLUActType
Definition: op.h:5608
Symbol sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6179
Symbol arccos(const std::string &symbol_name, Symbol data)
Definition: op.h:4261
Symbol hard_sigmoid(const std::string &symbol_name, Symbol data, mx_float alpha=0.2, mx_float beta=0.5)
Definition: op.h:1627
Symbol argmax_channel(const std::string &symbol_name, Symbol data)
Definition: op.h:1326
Symbol batch_take(const std::string &symbol_name, Symbol a, Symbol indices)
Definition: op.h:3465
Symbol mean(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2583
Symbol LinearRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:5865
Symbol softsign(const std::string &symbol_name, Symbol data)
Definition: op.h:1653
Symbol broadcast_like(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2886
Symbol sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, bool lazy_update=true)
Definition: op.h:6116
Symbol choose_element_0index(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:8071
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Definition: op.h:302
Symbol degrees(const std::string &symbol_name, Symbol data)
Definition: op.h:4315
Symbol shape_array(const std::string &symbol_name, Symbol data, dmlc::optional< int > lhs_begin=dmlc::optional< int >(), dmlc::optional< int > lhs_end=dmlc::optional< int >(), dmlc::optional< int > rhs_begin=dmlc::optional< int >(), dmlc::optional< int > rhs_end=dmlc::optional< int >())
Definition: op.h:1797
Symbol one_hot(const std::string &symbol_name, Symbol indices, int depth, double on_value=1, double off_value=0, One_hotDtype dtype=One_hotDtype::kFloat32)
Definition: op.h:3530
Symbol negative(const std::string &symbol_name, Symbol data)
Definition: op.h:1896
Symbol sum(const std::string &symbol_name, Symbol data, dmlc::optional< Shape > axis=dmlc::optional< Shape >(), bool keepdims=false, bool exclude=false)
Definition: op.h:2546
Symbol GridGenerator(const std::string &symbol_name, Symbol data, GridGeneratorTransformType transform_type, Shape target_shape=Shape(0, 0))
Definition: op.h:6807
Symbol size_array(const std::string &symbol_name, Symbol data)
Definition: op.h:1826
Symbol argmax(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1251
Operator interface.
Definition: operator.h:43
Symbol interface.
Definition: symbol.h:72
MakeLossNormalization
Definition: op.h:7812
Symbol log10(const std::string &symbol_name, Symbol data)
Definition: op.h:2355
Symbol log2(const std::string &symbol_name, Symbol data)
Definition: op.h:2376