mxnet
op.h
Go to the documentation of this file.
1 
8 #ifndef MXNET_CPP_OP_H_
9 #define MXNET_CPP_OP_H_
10 
11 #include <string>
12 #include <vector>
13 #include "mxnet-cpp/base.h"
14 #include "mxnet-cpp/shape.h"
15 #include "mxnet-cpp/op_util.h"
16 #include "mxnet-cpp/operator.h"
17 #include "dmlc/optional.h"
18 #include "nnvm/tuple.h"
19 
20 namespace mxnet {
21 namespace cpp {
22 
63 inline Symbol khatri_rao(const std::string& symbol_name,
64  const std::vector<Symbol>& args) {
65  return Operator("khatri_rao")
66 (args)
67  .CreateSymbol(symbol_name);
68 }
69 
80 inline Symbol all_finite(const std::string& symbol_name,
81  Symbol data,
82  bool init_output = true) {
83  return Operator("all_finite")
84  .SetParam("init_output", init_output)
85  .SetInput("data", data)
86  .CreateSymbol(symbol_name);
87 }
88 
100 inline Symbol multi_all_finite(const std::string& symbol_name,
101  const std::vector<Symbol>& data,
102  int num_arrays = 1,
103  bool init_output = true) {
104  return Operator("multi_all_finite")
105  .SetParam("num_arrays", num_arrays)
106  .SetParam("init_output", init_output)
107 (data)
108  .CreateSymbol(symbol_name);
109 }
110 
126 inline Symbol Custom(const std::string& symbol_name,
127  const std::vector<Symbol>& data,
128  const std::string& op_type) {
129  return Operator("Custom")
130 (data)
131  .CreateSymbol(symbol_name);
132 }
133 
156 inline Symbol broadcast_power(const std::string& symbol_name,
157  Symbol lhs,
158  Symbol rhs) {
159  return Operator("broadcast_power")
160  .SetInput("lhs", lhs)
161  .SetInput("rhs", rhs)
162  .CreateSymbol(symbol_name);
163 }
164 
189 inline Symbol broadcast_maximum(const std::string& symbol_name,
190  Symbol lhs,
191  Symbol rhs) {
192  return Operator("broadcast_maximum")
193  .SetInput("lhs", lhs)
194  .SetInput("rhs", rhs)
195  .CreateSymbol(symbol_name);
196 }
197 
222 inline Symbol broadcast_minimum(const std::string& symbol_name,
223  Symbol lhs,
224  Symbol rhs) {
225  return Operator("broadcast_minimum")
226  .SetInput("lhs", lhs)
227  .SetInput("rhs", rhs)
228  .CreateSymbol(symbol_name);
229 }
230 
261 inline Symbol broadcast_hypot(const std::string& symbol_name,
262  Symbol lhs,
263  Symbol rhs) {
264  return Operator("broadcast_hypot")
265  .SetInput("lhs", lhs)
266  .SetInput("rhs", rhs)
267  .CreateSymbol(symbol_name);
268 }
269 
344 inline Symbol Reshape(const std::string& symbol_name,
345  Symbol data,
346  Shape shape = Shape(),
347  bool reverse = false,
348  Shape target_shape = Shape(),
349  bool keep_highest = false) {
350  return Operator("Reshape")
351  .SetParam("shape", shape)
352  .SetParam("reverse", reverse)
353  .SetParam("target_shape", target_shape)
354  .SetParam("keep_highest", keep_highest)
355  .SetInput("data", data)
356  .CreateSymbol(symbol_name);
357 }
358 
392 inline Symbol Flatten(const std::string& symbol_name,
393  Symbol data) {
394  return Operator("Flatten")
395  .SetInput("data", data)
396  .CreateSymbol(symbol_name);
397 }
398 
435 inline Symbol transpose(const std::string& symbol_name,
436  Symbol data,
437  Shape axes = Shape()) {
438  return Operator("transpose")
439  .SetParam("axes", axes)
440  .SetInput("data", data)
441  .CreateSymbol(symbol_name);
442 }
443 
459 inline Symbol expand_dims(const std::string& symbol_name,
460  Symbol data,
461  int axis) {
462  return Operator("expand_dims")
463  .SetParam("axis", axis)
464  .SetInput("data", data)
465  .CreateSymbol(symbol_name);
466 }
467 
523 inline Symbol slice(const std::string& symbol_name,
524  Symbol data,
525  Shape begin,
526  Shape end,
527  Shape step = Shape()) {
528  return Operator("slice")
529  .SetParam("begin", begin)
530  .SetParam("end", end)
531  .SetParam("step", step)
532  .SetInput("data", data)
533  .CreateSymbol(symbol_name);
534 }
535 
568 inline Symbol slice_axis(const std::string& symbol_name,
569  Symbol data,
570  int axis,
571  int begin,
572  dmlc::optional<int> end) {
573  return Operator("slice_axis")
574  .SetParam("axis", axis)
575  .SetParam("begin", begin)
576  .SetParam("end", end)
577  .SetInput("data", data)
578  .CreateSymbol(symbol_name);
579 }
580 
642 inline Symbol slice_like(const std::string& symbol_name,
643  Symbol data,
644  Symbol shape_like,
645  Shape axes = Shape()) {
646  return Operator("slice_like")
647  .SetParam("axes", axes)
648  .SetInput("data", data)
649  .SetInput("shape_like", shape_like)
650  .CreateSymbol(symbol_name);
651 }
652 
687 inline Symbol clip(const std::string& symbol_name,
688  Symbol data,
689  mx_float a_min,
690  mx_float a_max) {
691  return Operator("clip")
692  .SetParam("a_min", a_min)
693  .SetParam("a_max", a_max)
694  .SetInput("data", data)
695  .CreateSymbol(symbol_name);
696 }
697 
732 inline Symbol repeat(const std::string& symbol_name,
733  Symbol data,
734  int repeats,
735  dmlc::optional<int> axis = dmlc::optional<int>()) {
736  return Operator("repeat")
737  .SetParam("repeats", repeats)
738  .SetParam("axis", axis)
739  .SetInput("data", data)
740  .CreateSymbol(symbol_name);
741 }
742 
788 inline Symbol tile(const std::string& symbol_name,
789  Symbol data,
790  Shape reps) {
791  return Operator("tile")
792  .SetParam("reps", reps)
793  .SetInput("data", data)
794  .CreateSymbol(symbol_name);
795 }
796 
820 inline Symbol reverse(const std::string& symbol_name,
821  Symbol data,
822  Shape axis) {
823  return Operator("reverse")
824  .SetParam("axis", axis)
825  .SetInput("data", data)
826  .CreateSymbol(symbol_name);
827 }
828 
852 inline Symbol stack(const std::string& symbol_name,
853  const std::vector<Symbol>& data,
854  int num_args,
855  int axis = 0) {
856  return Operator("stack")
857  .SetParam("num_args", num_args)
858  .SetParam("axis", axis)
859 (data)
860  .CreateSymbol(symbol_name);
861 }
862 
885 inline Symbol squeeze(const std::string& symbol_name,
886  const std::vector<Symbol>& data,
887  dmlc::optional<Shape> axis = dmlc::optional<Shape>()) {
888  return Operator("squeeze")
889  .SetParam("axis", axis)
890 (data)
891  .CreateSymbol(symbol_name);
892 }
893 
935 inline Symbol depth_to_space(const std::string& symbol_name,
936  Symbol data,
937  int block_size) {
938  return Operator("depth_to_space")
939  .SetParam("block_size", block_size)
940  .SetInput("data", data)
941  .CreateSymbol(symbol_name);
942 }
943 
987 inline Symbol space_to_depth(const std::string& symbol_name,
988  Symbol data,
989  int block_size) {
990  return Operator("space_to_depth")
991  .SetParam("block_size", block_size)
992  .SetInput("data", data)
993  .CreateSymbol(symbol_name);
994 }
995 
1019 inline Symbol zeros_like(const std::string& symbol_name,
1020  Symbol data) {
1021  return Operator("zeros_like")
1022  .SetInput("data", data)
1023  .CreateSymbol(symbol_name);
1024 }
1025 
1043 inline Symbol ones_like(const std::string& symbol_name,
1044  Symbol data) {
1045  return Operator("ones_like")
1046  .SetInput("data", data)
1047  .CreateSymbol(symbol_name);
1048 }
1049 
1072 inline Symbol add_n(const std::string& symbol_name,
1073  const std::vector<Symbol>& args) {
1074  return Operator("add_n")
1075 (args)
1076  .CreateSymbol(symbol_name);
1077 }
1078 
1110 inline Symbol argmax(const std::string& symbol_name,
1111  Symbol data,
1112  dmlc::optional<int> axis = dmlc::optional<int>(),
1113  bool keepdims = false) {
1114  return Operator("argmax")
1115  .SetParam("axis", axis)
1116  .SetParam("keepdims", keepdims)
1117  .SetInput("data", data)
1118  .CreateSymbol(symbol_name);
1119 }
1120 
1152 inline Symbol argmin(const std::string& symbol_name,
1153  Symbol data,
1154  dmlc::optional<int> axis = dmlc::optional<int>(),
1155  bool keepdims = false) {
1156  return Operator("argmin")
1157  .SetParam("axis", axis)
1158  .SetParam("keepdims", keepdims)
1159  .SetInput("data", data)
1160  .CreateSymbol(symbol_name);
1161 }
1162 
1185 inline Symbol argmax_channel(const std::string& symbol_name,
1186  Symbol data) {
1187  return Operator("argmax_channel")
1188  .SetInput("data", data)
1189  .CreateSymbol(symbol_name);
1190 }
1191 
/*!
 * \brief Choices for the "mode" parameter of pick().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside pick(), so the order here must not change.
 */
enum class PickMode {
  kClip = 0,  // sent as "clip"
  kWrap       // = 1, sent as "wrap"
};
1200 
1257 inline Symbol pick(const std::string& symbol_name,
1258  Symbol data,
1259  Symbol index,
1260  dmlc::optional<int> axis = dmlc::optional<int>(-1),
1261  bool keepdims = false,
1262  PickMode mode = PickMode::kClip) {
1263  static const char *PickModeValues[] = {
1264  "clip",
1265  "wrap"
1266  };
1267  return Operator("pick")
1268  .SetParam("axis", axis)
1269  .SetParam("keepdims", keepdims)
1270  .SetParam("mode", PickModeValues[int(mode)])
1271  .SetInput("data", data)
1272  .SetInput("index", index)
1273  .CreateSymbol(symbol_name);
1274 }
1275 
/*!
 * \brief Choices for the "forward_stype" parameter of dot().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside dot(), so the order here must not change.
 */
enum class DotForwardStype {
  kNone = 0,   // sent as "None"
  kCsr,        // = 1, sent as "csr"
  kDefault,    // = 2, sent as "default"
  kRow_sparse  // = 3, sent as "row_sparse"
};
1286 
1345 inline Symbol dot(const std::string& symbol_name,
1346  Symbol lhs,
1347  Symbol rhs,
1348  bool transpose_a = false,
1349  bool transpose_b = false,
1350  DotForwardStype forward_stype = DotForwardStype::kNone) {
1351  static const char *DotForwardStypeValues[] = {
1352  "None",
1353  "csr",
1354  "default",
1355  "row_sparse"
1356  };
1357  return Operator("dot")
1358  .SetParam("transpose_a", transpose_a)
1359  .SetParam("transpose_b", transpose_b)
1360  .SetParam("forward_stype", DotForwardStypeValues[int(forward_stype)])
1361  .SetInput("lhs", lhs)
1362  .SetInput("rhs", rhs)
1363  .CreateSymbol(symbol_name);
1364 }
1365 
1371  kNone = 0,
1372  kCsr = 1,
1373  kDefault = 2,
1374  kRow_sparse = 3
1375 };
1376 
1402 inline Symbol batch_dot(const std::string& symbol_name,
1403  Symbol lhs,
1404  Symbol rhs,
1405  bool transpose_a = false,
1406  bool transpose_b = false,
1408  static const char *Batch_dotForwardStypeValues[] = {
1409  "None",
1410  "csr",
1411  "default",
1412  "row_sparse"
1413  };
1414  return Operator("batch_dot")
1415  .SetParam("transpose_a", transpose_a)
1416  .SetParam("transpose_b", transpose_b)
1417  .SetParam("forward_stype", Batch_dotForwardStypeValues[int(forward_stype)])
1418  .SetInput("lhs", lhs)
1419  .SetInput("rhs", rhs)
1420  .CreateSymbol(symbol_name);
1421 }
1422 
1455 inline Symbol broadcast_add(const std::string& symbol_name,
1456  Symbol lhs,
1457  Symbol rhs) {
1458  return Operator("broadcast_add")
1459  .SetInput("lhs", lhs)
1460  .SetInput("rhs", rhs)
1461  .CreateSymbol(symbol_name);
1462 }
1463 
1496 inline Symbol broadcast_sub(const std::string& symbol_name,
1497  Symbol lhs,
1498  Symbol rhs) {
1499  return Operator("broadcast_sub")
1500  .SetInput("lhs", lhs)
1501  .SetInput("rhs", rhs)
1502  .CreateSymbol(symbol_name);
1503 }
1504 
1531 inline Symbol broadcast_mul(const std::string& symbol_name,
1532  Symbol lhs,
1533  Symbol rhs) {
1534  return Operator("broadcast_mul")
1535  .SetInput("lhs", lhs)
1536  .SetInput("rhs", rhs)
1537  .CreateSymbol(symbol_name);
1538 }
1539 
1566 inline Symbol broadcast_div(const std::string& symbol_name,
1567  Symbol lhs,
1568  Symbol rhs) {
1569  return Operator("broadcast_div")
1570  .SetInput("lhs", lhs)
1571  .SetInput("rhs", rhs)
1572  .CreateSymbol(symbol_name);
1573 }
1574 
1597 inline Symbol broadcast_mod(const std::string& symbol_name,
1598  Symbol lhs,
1599  Symbol rhs) {
1600  return Operator("broadcast_mod")
1601  .SetInput("lhs", lhs)
1602  .SetInput("rhs", rhs)
1603  .CreateSymbol(symbol_name);
1604 }
1605 
1625 inline Symbol relu(const std::string& symbol_name,
1626  Symbol data) {
1627  return Operator("relu")
1628  .SetInput("data", data)
1629  .CreateSymbol(symbol_name);
1630 }
1631 
1647 inline Symbol sigmoid(const std::string& symbol_name,
1648  Symbol data) {
1649  return Operator("sigmoid")
1650  .SetInput("data", data)
1651  .CreateSymbol(symbol_name);
1652 }
1653 
1669 inline Symbol hard_sigmoid(const std::string& symbol_name,
1670  Symbol data,
1671  mx_float alpha = 0.200000003,
1672  mx_float beta = 0.5) {
1673  return Operator("hard_sigmoid")
1674  .SetParam("alpha", alpha)
1675  .SetParam("beta", beta)
1676  .SetInput("data", data)
1677  .CreateSymbol(symbol_name);
1678 }
1679 
1695 inline Symbol softsign(const std::string& symbol_name,
1696  Symbol data) {
1697  return Operator("softsign")
1698  .SetInput("data", data)
1699  .CreateSymbol(symbol_name);
1700 }
1701 
1735 inline Symbol BlockGrad(const std::string& symbol_name,
1736  Symbol data) {
1737  return Operator("BlockGrad")
1738  .SetInput("data", data)
1739  .CreateSymbol(symbol_name);
1740 }
1741 
1771 inline Symbol make_loss(const std::string& symbol_name,
1772  Symbol data) {
1773  return Operator("make_loss")
1774  .SetInput("data", data)
1775  .CreateSymbol(symbol_name);
1776 }
1777 
1812 inline Symbol reshape_like(const std::string& symbol_name,
1813  Symbol lhs,
1814  Symbol rhs) {
1815  return Operator("reshape_like")
1816  .SetInput("lhs", lhs)
1817  .SetInput("rhs", rhs)
1818  .CreateSymbol(symbol_name);
1819 }
1820 
1839 inline Symbol shape_array(const std::string& symbol_name,
1840  Symbol data,
1841  dmlc::optional<int> lhs_begin = dmlc::optional<int>(),
1842  dmlc::optional<int> lhs_end = dmlc::optional<int>(),
1843  dmlc::optional<int> rhs_begin = dmlc::optional<int>(),
1844  dmlc::optional<int> rhs_end = dmlc::optional<int>()) {
1845  return Operator("shape_array")
1846  .SetParam("lhs_begin", lhs_begin)
1847  .SetParam("lhs_end", lhs_end)
1848  .SetParam("rhs_begin", rhs_begin)
1849  .SetParam("rhs_end", rhs_end)
1850  .SetInput("data", data)
1851  .CreateSymbol(symbol_name);
1852 }
1853 
1868 inline Symbol size_array(const std::string& symbol_name,
1869  Symbol data) {
1870  return Operator("size_array")
1871  .SetInput("data", data)
1872  .CreateSymbol(symbol_name);
1873 }
1874 
/*!
 * \brief Choices for the "dtype" parameter of Cast().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside Cast(), so the order here must not change.
 */
enum class CastDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32,      // = 1, sent as "float32"
  kFloat64,      // = 2, sent as "float64"
  kInt32,        // = 3, sent as "int32"
  kInt64,        // = 4, sent as "int64"
  kInt8,         // = 5, sent as "int8"
  kUint8         // = 6, sent as "uint8"
};
1886 
1906 inline Symbol Cast(const std::string& symbol_name,
1907  Symbol data,
1908  CastDtype dtype) {
1909  static const char *CastDtypeValues[] = {
1910  "float16",
1911  "float32",
1912  "float64",
1913  "int32",
1914  "int64",
1915  "int8",
1916  "uint8"
1917  };
1918  return Operator("Cast")
1919  .SetParam("dtype", CastDtypeValues[int(dtype)])
1920  .SetInput("data", data)
1921  .CreateSymbol(symbol_name);
1922 }
1923 
1938 inline Symbol negative(const std::string& symbol_name,
1939  Symbol data) {
1940  return Operator("negative")
1941  .SetInput("data", data)
1942  .CreateSymbol(symbol_name);
1943 }
1944 
1961 inline Symbol reciprocal(const std::string& symbol_name,
1962  Symbol data) {
1963  return Operator("reciprocal")
1964  .SetInput("data", data)
1965  .CreateSymbol(symbol_name);
1966 }
1967 
1988 inline Symbol abs(const std::string& symbol_name,
1989  Symbol data) {
1990  return Operator("abs")
1991  .SetInput("data", data)
1992  .CreateSymbol(symbol_name);
1993 }
1994 
2015 inline Symbol sign(const std::string& symbol_name,
2016  Symbol data) {
2017  return Operator("sign")
2018  .SetInput("data", data)
2019  .CreateSymbol(symbol_name);
2020 }
2021 
2042 inline Symbol round(const std::string& symbol_name,
2043  Symbol data) {
2044  return Operator("round")
2045  .SetInput("data", data)
2046  .CreateSymbol(symbol_name);
2047 }
2048 
2073 inline Symbol rint(const std::string& symbol_name,
2074  Symbol data) {
2075  return Operator("rint")
2076  .SetInput("data", data)
2077  .CreateSymbol(symbol_name);
2078 }
2079 
2102 inline Symbol ceil(const std::string& symbol_name,
2103  Symbol data) {
2104  return Operator("ceil")
2105  .SetInput("data", data)
2106  .CreateSymbol(symbol_name);
2107 }
2108 
2131 inline Symbol floor(const std::string& symbol_name,
2132  Symbol data) {
2133  return Operator("floor")
2134  .SetInput("data", data)
2135  .CreateSymbol(symbol_name);
2136 }
2137 
2161 inline Symbol trunc(const std::string& symbol_name,
2162  Symbol data) {
2163  return Operator("trunc")
2164  .SetInput("data", data)
2165  .CreateSymbol(symbol_name);
2166 }
2167 
2189 inline Symbol fix(const std::string& symbol_name,
2190  Symbol data) {
2191  return Operator("fix")
2192  .SetInput("data", data)
2193  .CreateSymbol(symbol_name);
2194 }
2195 
2219 inline Symbol square(const std::string& symbol_name,
2220  Symbol data) {
2221  return Operator("square")
2222  .SetInput("data", data)
2223  .CreateSymbol(symbol_name);
2224 }
2225 
2249 inline Symbol sqrt(const std::string& symbol_name,
2250  Symbol data) {
2251  return Operator("sqrt")
2252  .SetInput("data", data)
2253  .CreateSymbol(symbol_name);
2254 }
2255 
2275 inline Symbol rsqrt(const std::string& symbol_name,
2276  Symbol data) {
2277  return Operator("rsqrt")
2278  .SetInput("data", data)
2279  .CreateSymbol(symbol_name);
2280 }
2281 
2305 inline Symbol cbrt(const std::string& symbol_name,
2306  Symbol data) {
2307  return Operator("cbrt")
2308  .SetInput("data", data)
2309  .CreateSymbol(symbol_name);
2310 }
2311 
2326 inline Symbol erf(const std::string& symbol_name,
2327  Symbol data) {
2328  return Operator("erf")
2329  .SetInput("data", data)
2330  .CreateSymbol(symbol_name);
2331 }
2332 
2347 inline Symbol erfinv(const std::string& symbol_name,
2348  Symbol data) {
2349  return Operator("erfinv")
2350  .SetInput("data", data)
2351  .CreateSymbol(symbol_name);
2352 }
2353 
2371 inline Symbol rcbrt(const std::string& symbol_name,
2372  Symbol data) {
2373  return Operator("rcbrt")
2374  .SetInput("data", data)
2375  .CreateSymbol(symbol_name);
2376 }
2377 
2397 inline Symbol exp(const std::string& symbol_name,
2398  Symbol data) {
2399  return Operator("exp")
2400  .SetInput("data", data)
2401  .CreateSymbol(symbol_name);
2402 }
2403 
2418 inline Symbol log(const std::string& symbol_name,
2419  Symbol data) {
2420  return Operator("log")
2421  .SetInput("data", data)
2422  .CreateSymbol(symbol_name);
2423 }
2424 
2439 inline Symbol log10(const std::string& symbol_name,
2440  Symbol data) {
2441  return Operator("log10")
2442  .SetInput("data", data)
2443  .CreateSymbol(symbol_name);
2444 }
2445 
2460 inline Symbol log2(const std::string& symbol_name,
2461  Symbol data) {
2462  return Operator("log2")
2463  .SetInput("data", data)
2464  .CreateSymbol(symbol_name);
2465 }
2466 
2486 inline Symbol log1p(const std::string& symbol_name,
2487  Symbol data) {
2488  return Operator("log1p")
2489  .SetInput("data", data)
2490  .CreateSymbol(symbol_name);
2491 }
2492 
2511 inline Symbol expm1(const std::string& symbol_name,
2512  Symbol data) {
2513  return Operator("expm1")
2514  .SetInput("data", data)
2515  .CreateSymbol(symbol_name);
2516 }
2517 
2529 inline Symbol gamma(const std::string& symbol_name,
2530  Symbol data) {
2531  return Operator("gamma")
2532  .SetInput("data", data)
2533  .CreateSymbol(symbol_name);
2534 }
2535 
2547 inline Symbol gammaln(const std::string& symbol_name,
2548  Symbol data) {
2549  return Operator("gammaln")
2550  .SetInput("data", data)
2551  .CreateSymbol(symbol_name);
2552 }
2553 
2565 inline Symbol logical_not(const std::string& symbol_name,
2566  Symbol data) {
2567  return Operator("logical_not")
2568  .SetInput("data", data)
2569  .CreateSymbol(symbol_name);
2570 }
2571 
/*!
 * \brief Choices for the "dtype" parameter of amp_cast().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside amp_cast(), so the order here must not change.
 */
enum class Amp_castDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32,      // = 1, sent as "float32"
  kFloat64,      // = 2, sent as "float64"
  kInt32,        // = 3, sent as "int32"
  kInt64,        // = 4, sent as "int64"
  kInt8,         // = 5, sent as "int8"
  kUint8         // = 6, sent as "uint8"
};
2583 
2596 inline Symbol amp_cast(const std::string& symbol_name,
2597  Symbol data,
2598  Amp_castDtype dtype) {
2599  static const char *Amp_castDtypeValues[] = {
2600  "float16",
2601  "float32",
2602  "float64",
2603  "int32",
2604  "int64",
2605  "int8",
2606  "uint8"
2607  };
2608  return Operator("amp_cast")
2609  .SetParam("dtype", Amp_castDtypeValues[int(dtype)])
2610  .SetInput("data", data)
2611  .CreateSymbol(symbol_name);
2612 }
2613 
2627 inline Symbol amp_multicast(const std::string& symbol_name,
2628  const std::vector<Symbol>& data,
2629  int num_outputs) {
2630  return Operator("amp_multicast")
2631  .SetParam("num_outputs", num_outputs)
2632 (data)
2633  .CreateSymbol(symbol_name);
2634 }
2635 
/*!
 * \brief Choices for the "ret_typ" parameter of topk().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside topk(), so the order here must not change.
 */
enum class TopkRetTyp {
  kBoth = 0,  // sent as "both"
  kIndices,   // = 1, sent as "indices"
  kMask,      // = 2, sent as "mask"
  kValue      // = 3, sent as "value"
};
2647 
/*!
 * \brief Choices for the "dtype" parameter of topk().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside topk(), so the order here must not change.
 */
enum class TopkDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32,      // = 1, sent as "float32"
  kFloat64,      // = 2, sent as "float64"
  kInt32,        // = 3, sent as "int32"
  kUint8         // = 4, sent as "uint8"
};
2657 
2701 inline Symbol topk(const std::string& symbol_name,
2702  Symbol data,
2703  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2704  int k = 1,
2705  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
2706  bool is_ascend = false,
2707  TopkDtype dtype = TopkDtype::kFloat32) {
2708  static const char *TopkRetTypValues[] = {
2709  "both",
2710  "indices",
2711  "mask",
2712  "value"
2713  };
2714  static const char *TopkDtypeValues[] = {
2715  "float16",
2716  "float32",
2717  "float64",
2718  "int32",
2719  "uint8"
2720  };
2721  return Operator("topk")
2722  .SetParam("axis", axis)
2723  .SetParam("k", k)
2724  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
2725  .SetParam("is_ascend", is_ascend)
2726  .SetParam("dtype", TopkDtypeValues[int(dtype)])
2727  .SetInput("data", data)
2728  .CreateSymbol(symbol_name);
2729 }
2730 
2763 inline Symbol sort(const std::string& symbol_name,
2764  Symbol data,
2765  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2766  bool is_ascend = true) {
2767  return Operator("sort")
2768  .SetParam("axis", axis)
2769  .SetParam("is_ascend", is_ascend)
2770  .SetInput("data", data)
2771  .CreateSymbol(symbol_name);
2772 }
2773 
/*!
 * \brief Choices for the "dtype" parameter of argsort().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside argsort(), so the order here must not change.
 */
enum class ArgsortDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32,      // = 1, sent as "float32"
  kFloat64,      // = 2, sent as "float64"
  kInt32,        // = 3, sent as "int32"
  kUint8         // = 4, sent as "uint8"
};
2784 
2817 inline Symbol argsort(const std::string& symbol_name,
2818  Symbol data,
2819  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2820  bool is_ascend = true,
2822  static const char *ArgsortDtypeValues[] = {
2823  "float16",
2824  "float32",
2825  "float64",
2826  "int32",
2827  "uint8"
2828  };
2829  return Operator("argsort")
2830  .SetParam("axis", axis)
2831  .SetParam("is_ascend", is_ascend)
2832  .SetParam("dtype", ArgsortDtypeValues[int(dtype)])
2833  .SetInput("data", data)
2834  .CreateSymbol(symbol_name);
2835 }
2836 
2856 inline Symbol elemwise_add(const std::string& symbol_name,
2857  Symbol lhs,
2858  Symbol rhs) {
2859  return Operator("elemwise_add")
2860  .SetInput("lhs", lhs)
2861  .SetInput("rhs", rhs)
2862  .CreateSymbol(symbol_name);
2863 }
2864 
2884 inline Symbol elemwise_sub(const std::string& symbol_name,
2885  Symbol lhs,
2886  Symbol rhs) {
2887  return Operator("elemwise_sub")
2888  .SetInput("lhs", lhs)
2889  .SetInput("rhs", rhs)
2890  .CreateSymbol(symbol_name);
2891 }
2892 
2911 inline Symbol elemwise_mul(const std::string& symbol_name,
2912  Symbol lhs,
2913  Symbol rhs) {
2914  return Operator("elemwise_mul")
2915  .SetInput("lhs", lhs)
2916  .SetInput("rhs", rhs)
2917  .CreateSymbol(symbol_name);
2918 }
2919 
2931 inline Symbol elemwise_div(const std::string& symbol_name,
2932  Symbol lhs,
2933  Symbol rhs) {
2934  return Operator("elemwise_div")
2935  .SetInput("lhs", lhs)
2936  .SetInput("rhs", rhs)
2937  .CreateSymbol(symbol_name);
2938 }
2939 
/*!
 * \brief Choices for the "dtype" parameter of Embedding().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside Embedding(), so the order here must not change.
 */
enum class EmbeddingDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32,      // = 1, sent as "float32"
  kFloat64,      // = 2, sent as "float64"
  kInt32,        // = 3, sent as "int32"
  kInt64,        // = 4, sent as "int64"
  kInt8,         // = 5, sent as "int8"
  kUint8         // = 6, sent as "uint8"
};
2951 
3015 inline Symbol Embedding(const std::string& symbol_name,
3016  Symbol data,
3017  Symbol weight,
3018  int input_dim,
3019  int output_dim,
3021  bool sparse_grad = false) {
3022  static const char *EmbeddingDtypeValues[] = {
3023  "float16",
3024  "float32",
3025  "float64",
3026  "int32",
3027  "int64",
3028  "int8",
3029  "uint8"
3030  };
3031  return Operator("Embedding")
3032  .SetParam("input_dim", input_dim)
3033  .SetParam("output_dim", output_dim)
3034  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
3035  .SetParam("sparse_grad", sparse_grad)
3036  .SetInput("data", data)
3037  .SetInput("weight", weight)
3038  .CreateSymbol(symbol_name);
3039 }
3040 
/*!
 * \brief Choices for the "mode" parameter of take().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside take(), so the order here must not change.
 */
enum class TakeMode {
  kClip = 0,  // sent as "clip"
  kRaise,     // = 1, sent as "raise"
  kWrap       // = 2, sent as "wrap"
};
3050 
3113 inline Symbol take(const std::string& symbol_name,
3114  Symbol a,
3115  Symbol indices,
3116  int axis = 0,
3117  TakeMode mode = TakeMode::kClip) {
3118  static const char *TakeModeValues[] = {
3119  "clip",
3120  "raise",
3121  "wrap"
3122  };
3123  return Operator("take")
3124  .SetParam("axis", axis)
3125  .SetParam("mode", TakeModeValues[int(mode)])
3126  .SetInput("a", a)
3127  .SetInput("indices", indices)
3128  .CreateSymbol(symbol_name);
3129 }
3130 
3159 inline Symbol batch_take(const std::string& symbol_name,
3160  Symbol a,
3161  Symbol indices) {
3162  return Operator("batch_take")
3163  .SetInput("a", a)
3164  .SetInput("indices", indices)
3165  .CreateSymbol(symbol_name);
3166 }
3167 
/*!
 * \brief Choices for the "dtype" parameter of one_hot().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside one_hot(), so the order here must not change.
 */
enum class One_hotDtype {
  kFloat16 = 0,  // sent as "float16"
  kFloat32,      // = 1, sent as "float32"
  kFloat64,      // = 2, sent as "float64"
  kInt32,        // = 3, sent as "int32"
  kInt64,        // = 4, sent as "int64"
  kInt8,         // = 5, sent as "int8"
  kUint8         // = 6, sent as "uint8"
};
3179 
3224 inline Symbol one_hot(const std::string& symbol_name,
3225  Symbol indices,
3226  int depth,
3227  double on_value = 1,
3228  double off_value = 0,
3230  static const char *One_hotDtypeValues[] = {
3231  "float16",
3232  "float32",
3233  "float64",
3234  "int32",
3235  "int64",
3236  "int8",
3237  "uint8"
3238  };
3239  return Operator("one_hot")
3240  .SetParam("depth", depth)
3241  .SetParam("on_value", on_value)
3242  .SetParam("off_value", off_value)
3243  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
3244  .SetInput("indices", indices)
3245  .CreateSymbol(symbol_name);
3246 }
3247 
3279 inline Symbol gather_nd(const std::string& symbol_name,
3280  Symbol data,
3281  Symbol indices) {
3282  return Operator("gather_nd")
3283  .SetInput("data", data)
3284  .SetInput("indices", indices)
3285  .CreateSymbol(symbol_name);
3286 }
3287 
3339 inline Symbol scatter_nd(const std::string& symbol_name,
3340  Symbol data,
3341  Symbol indices,
3342  Shape shape) {
3343  return Operator("scatter_nd")
3344  .SetParam("shape", shape)
3345  .SetInput("data", data)
3346  .SetInput("indices", indices)
3347  .CreateSymbol(symbol_name);
3348 }
3349 
3372 inline Symbol broadcast_equal(const std::string& symbol_name,
3373  Symbol lhs,
3374  Symbol rhs) {
3375  return Operator("broadcast_equal")
3376  .SetInput("lhs", lhs)
3377  .SetInput("rhs", rhs)
3378  .CreateSymbol(symbol_name);
3379 }
3380 
3403 inline Symbol broadcast_not_equal(const std::string& symbol_name,
3404  Symbol lhs,
3405  Symbol rhs) {
3406  return Operator("broadcast_not_equal")
3407  .SetInput("lhs", lhs)
3408  .SetInput("rhs", rhs)
3409  .CreateSymbol(symbol_name);
3410 }
3411 
3434 inline Symbol broadcast_greater(const std::string& symbol_name,
3435  Symbol lhs,
3436  Symbol rhs) {
3437  return Operator("broadcast_greater")
3438  .SetInput("lhs", lhs)
3439  .SetInput("rhs", rhs)
3440  .CreateSymbol(symbol_name);
3441 }
3442 
3465 inline Symbol broadcast_greater_equal(const std::string& symbol_name,
3466  Symbol lhs,
3467  Symbol rhs) {
3468  return Operator("broadcast_greater_equal")
3469  .SetInput("lhs", lhs)
3470  .SetInput("rhs", rhs)
3471  .CreateSymbol(symbol_name);
3472 }
3473 
3496 inline Symbol broadcast_lesser(const std::string& symbol_name,
3497  Symbol lhs,
3498  Symbol rhs) {
3499  return Operator("broadcast_lesser")
3500  .SetInput("lhs", lhs)
3501  .SetInput("rhs", rhs)
3502  .CreateSymbol(symbol_name);
3503 }
3504 
3527 inline Symbol broadcast_lesser_equal(const std::string& symbol_name,
3528  Symbol lhs,
3529  Symbol rhs) {
3530  return Operator("broadcast_lesser_equal")
3531  .SetInput("lhs", lhs)
3532  .SetInput("rhs", rhs)
3533  .CreateSymbol(symbol_name);
3534 }
3535 
3558 inline Symbol broadcast_logical_and(const std::string& symbol_name,
3559  Symbol lhs,
3560  Symbol rhs) {
3561  return Operator("broadcast_logical_and")
3562  .SetInput("lhs", lhs)
3563  .SetInput("rhs", rhs)
3564  .CreateSymbol(symbol_name);
3565 }
3566 
3589 inline Symbol broadcast_logical_or(const std::string& symbol_name,
3590  Symbol lhs,
3591  Symbol rhs) {
3592  return Operator("broadcast_logical_or")
3593  .SetInput("lhs", lhs)
3594  .SetInput("rhs", rhs)
3595  .CreateSymbol(symbol_name);
3596 }
3597 
3620 inline Symbol broadcast_logical_xor(const std::string& symbol_name,
3621  Symbol lhs,
3622  Symbol rhs) {
3623  return Operator("broadcast_logical_xor")
3624  .SetInput("lhs", lhs)
3625  .SetInput("rhs", rhs)
3626  .CreateSymbol(symbol_name);
3627 }
3628 
3693 inline Symbol diag(const std::string& symbol_name,
3694  Symbol data,
3695  int k = 0,
3696  int axis1 = 0,
3697  int axis2 = 1) {
3698  return Operator("diag")
3699  .SetParam("k", k)
3700  .SetParam("axis1", axis1)
3701  .SetParam("axis2", axis2)
3702  .SetInput("data", data)
3703  .CreateSymbol(symbol_name);
3704 }
3705 
3741 inline Symbol where(const std::string& symbol_name,
3742  Symbol condition,
3743  Symbol x,
3744  Symbol y) {
3745  return Operator("where")
3746  .SetInput("condition", condition)
3747  .SetInput("x", x)
3748  .SetInput("y", y)
3749  .CreateSymbol(symbol_name);
3750 }
3751 
3778 inline Symbol smooth_l1(const std::string& symbol_name,
3779  Symbol data,
3780  mx_float scalar) {
3781  return Operator("smooth_l1")
3782  .SetParam("scalar", scalar)
3783  .SetInput("data", data)
3784  .CreateSymbol(symbol_name);
3785 }
3786 
/*!
 * \brief Choices for the "stype" parameter of cast_storage().
 *
 * Each enumerator's integer value is the index of its string in the lookup
 * table inside cast_storage(), so the order here must not change.
 */
enum class Cast_storageStype {
  kCsr = 0,    // sent as "csr"
  kDefault,    // = 1, sent as "default"
  kRow_sparse  // = 2, sent as "row_sparse"
};
3794 
3840 inline Symbol cast_storage(const std::string& symbol_name,
3841  Symbol data,
3842  Cast_storageStype stype) {
3843  static const char *Cast_storageStypeValues[] = {
3844  "csr",
3845  "default",
3846  "row_sparse"
3847  };
3848  return Operator("cast_storage")
3849  .SetParam("stype", Cast_storageStypeValues[int(stype)])
3850  .SetInput("data", data)
3851  .CreateSymbol(symbol_name);
3852 }
3853 
3912 inline Symbol sum(const std::string& symbol_name,
3913  Symbol data,
3914  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
3915  bool keepdims = false,
3916  bool exclude = false) {
3917  return Operator("sum")
3918  .SetParam("axis", axis)
3919  .SetParam("keepdims", keepdims)
3920  .SetParam("exclude", exclude)
3921  .SetInput("data", data)
3922  .CreateSymbol(symbol_name);
3923 }
3924 
3949 inline Symbol mean(const std::string& symbol_name,
3950  Symbol data,
3951  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
3952  bool keepdims = false,
3953  bool exclude = false) {
3954  return Operator("mean")
3955  .SetParam("axis", axis)
3956  .SetParam("keepdims", keepdims)
3957  .SetParam("exclude", exclude)
3958  .SetInput("data", data)
3959  .CreateSymbol(symbol_name);
3960 }
3961 
3986 inline Symbol prod(const std::string& symbol_name,
3987  Symbol data,
3988  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
3989  bool keepdims = false,
3990  bool exclude = false) {
3991  return Operator("prod")
3992  .SetParam("axis", axis)
3993  .SetParam("keepdims", keepdims)
3994  .SetParam("exclude", exclude)
3995  .SetInput("data", data)
3996  .CreateSymbol(symbol_name);
3997 }
3998 
4025 inline Symbol nansum(const std::string& symbol_name,
4026  Symbol data,
4027  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
4028  bool keepdims = false,
4029  bool exclude = false) {
4030  return Operator("nansum")
4031  .SetParam("axis", axis)
4032  .SetParam("keepdims", keepdims)
4033  .SetParam("exclude", exclude)
4034  .SetInput("data", data)
4035  .CreateSymbol(symbol_name);
4036 }
4037 
4064 inline Symbol nanprod(const std::string& symbol_name,
4065  Symbol data,
4066  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
4067  bool keepdims = false,
4068  bool exclude = false) {
4069  return Operator("nanprod")
4070  .SetParam("axis", axis)
4071  .SetParam("keepdims", keepdims)
4072  .SetParam("exclude", exclude)
4073  .SetInput("data", data)
4074  .CreateSymbol(symbol_name);
4075 }
4076 
4101 inline Symbol max(const std::string& symbol_name,
4102  Symbol data,
4103  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
4104  bool keepdims = false,
4105  bool exclude = false) {
4106  return Operator("max")
4107  .SetParam("axis", axis)
4108  .SetParam("keepdims", keepdims)
4109  .SetParam("exclude", exclude)
4110  .SetInput("data", data)
4111  .CreateSymbol(symbol_name);
4112 }
4113 
4138 inline Symbol min(const std::string& symbol_name,
4139  Symbol data,
4140  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
4141  bool keepdims = false,
4142  bool exclude = false) {
4143  return Operator("min")
4144  .SetParam("axis", axis)
4145  .SetParam("keepdims", keepdims)
4146  .SetParam("exclude", exclude)
4147  .SetInput("data", data)
4148  .CreateSymbol(symbol_name);
4149 }
4150 
4180 inline Symbol broadcast_axis(const std::string& symbol_name,
4181  Symbol data,
4182  Shape axis = Shape(),
4183  Shape size = Shape()) {
4184  return Operator("broadcast_axis")
4185  .SetParam("axis", axis)
4186  .SetParam("size", size)
4187  .SetInput("data", data)
4188  .CreateSymbol(symbol_name);
4189 }
4190 
4219 inline Symbol broadcast_to(const std::string& symbol_name,
4220  Symbol data,
4221  Shape shape = Shape()) {
4222  return Operator("broadcast_to")
4223  .SetParam("shape", shape)
4224  .SetInput("data", data)
4225  .CreateSymbol(symbol_name);
4226 }
4227 
4256 inline Symbol broadcast_like(const std::string& symbol_name,
4257  Symbol lhs,
4258  Symbol rhs,
4259  dmlc::optional<Shape> lhs_axes = dmlc::optional<Shape>(),
4260  dmlc::optional<Shape> rhs_axes = dmlc::optional<Shape>()) {
4261  return Operator("broadcast_like")
4262  .SetParam("lhs_axes", lhs_axes)
4263  .SetParam("rhs_axes", rhs_axes)
4264  .SetInput("lhs", lhs)
4265  .SetInput("rhs", rhs)
4266  .CreateSymbol(symbol_name);
4267 }
4268 
/*!
 * \brief Choices for the "out_dtype" parameter of norm().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by norm(); the matching string is shown next to it.
 */
enum class NormOutDtype {
  kNone = 0,  // "None"
  kFloat16,   // "float16"
  kFloat32,   // "float32"
  kFloat64,   // "float64"
  kInt32,     // "int32"
  kInt64,     // "int64"
  kInt8       // "int8"
};
4280 
4325 inline Symbol norm(const std::string& symbol_name,
4326  Symbol data,
4327  int ord = 2,
4328  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
4329  NormOutDtype out_dtype = NormOutDtype::kNone,
4330  bool keepdims = false) {
4331  static const char *NormOutDtypeValues[] = {
4332  "None",
4333  "float16",
4334  "float32",
4335  "float64",
4336  "int32",
4337  "int64",
4338  "int8"
4339  };
4340  return Operator("norm")
4341  .SetParam("ord", ord)
4342  .SetParam("axis", axis)
4343  .SetParam("out_dtype", NormOutDtypeValues[int(out_dtype)])
4344  .SetParam("keepdims", keepdims)
4345  .SetInput("data", data)
4346  .CreateSymbol(symbol_name);
4347 }
4348 
4370 inline Symbol sin(const std::string& symbol_name,
4371  Symbol data) {
4372  return Operator("sin")
4373  .SetInput("data", data)
4374  .CreateSymbol(symbol_name);
4375 }
4376 
4394 inline Symbol cos(const std::string& symbol_name,
4395  Symbol data) {
4396  return Operator("cos")
4397  .SetInput("data", data)
4398  .CreateSymbol(symbol_name);
4399 }
4400 
4422 inline Symbol tan(const std::string& symbol_name,
4423  Symbol data) {
4424  return Operator("tan")
4425  .SetInput("data", data)
4426  .CreateSymbol(symbol_name);
4427 }
4428 
4451 inline Symbol arcsin(const std::string& symbol_name,
4452  Symbol data) {
4453  return Operator("arcsin")
4454  .SetInput("data", data)
4455  .CreateSymbol(symbol_name);
4456 }
4457 
4476 inline Symbol arccos(const std::string& symbol_name,
4477  Symbol data) {
4478  return Operator("arccos")
4479  .SetInput("data", data)
4480  .CreateSymbol(symbol_name);
4481 }
4482 
4504 inline Symbol arctan(const std::string& symbol_name,
4505  Symbol data) {
4506  return Operator("arctan")
4507  .SetInput("data", data)
4508  .CreateSymbol(symbol_name);
4509 }
4510 
4530 inline Symbol degrees(const std::string& symbol_name,
4531  Symbol data) {
4532  return Operator("degrees")
4533  .SetInput("data", data)
4534  .CreateSymbol(symbol_name);
4535 }
4536 
4556 inline Symbol radians(const std::string& symbol_name,
4557  Symbol data) {
4558  return Operator("radians")
4559  .SetInput("data", data)
4560  .CreateSymbol(symbol_name);
4561 }
4562 
4582 inline Symbol sinh(const std::string& symbol_name,
4583  Symbol data) {
4584  return Operator("sinh")
4585  .SetInput("data", data)
4586  .CreateSymbol(symbol_name);
4587 }
4588 
4604 inline Symbol cosh(const std::string& symbol_name,
4605  Symbol data) {
4606  return Operator("cosh")
4607  .SetInput("data", data)
4608  .CreateSymbol(symbol_name);
4609 }
4610 
4630 inline Symbol tanh(const std::string& symbol_name,
4631  Symbol data) {
4632  return Operator("tanh")
4633  .SetInput("data", data)
4634  .CreateSymbol(symbol_name);
4635 }
4636 
4654 inline Symbol arcsinh(const std::string& symbol_name,
4655  Symbol data) {
4656  return Operator("arcsinh")
4657  .SetInput("data", data)
4658  .CreateSymbol(symbol_name);
4659 }
4660 
4674 inline Symbol arccosh(const std::string& symbol_name,
4675  Symbol data) {
4676  return Operator("arccosh")
4677  .SetInput("data", data)
4678  .CreateSymbol(symbol_name);
4679 }
4680 
4698 inline Symbol arctanh(const std::string& symbol_name,
4699  Symbol data) {
4700  return Operator("arctanh")
4701  .SetInput("data", data)
4702  .CreateSymbol(symbol_name);
4703 }
4704 
/*!
 * \brief Choices for the "pool_type" parameter of Pooling().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by Pooling(); the matching string is shown next to it.
 */
enum class PoolingPoolType {
  kAvg = 0,  // "avg"
  kLp,       // "lp"
  kMax,      // "max"
  kSum       // "sum"
};
4713 
4717  kFull = 0,
4718  kSame = 1,
4719  kValid = 2
4720 };
4721 
4725 enum class PoolingLayout {
4726  kNone = 0,
4727  kNCDHW = 1,
4728  kNCHW = 2,
4729  kNCW = 3,
4730  kNDHWC = 4,
4731  kNHWC = 5,
4732  kNWC = 6
4733 };
4734 
4805 inline Symbol Pooling(const std::string& symbol_name,
4806  Symbol data,
4807  Shape kernel = Shape(),
4809  bool global_pool = false,
4810  bool cudnn_off = false,
4812  Shape stride = Shape(),
4813  Shape pad = Shape(),
4814  dmlc::optional<int> p_value = dmlc::optional<int>(),
4815  dmlc::optional<bool> count_include_pad = dmlc::optional<bool>(),
4817  static const char *PoolingPoolTypeValues[] = {
4818  "avg",
4819  "lp",
4820  "max",
4821  "sum"
4822  };
4823  static const char *PoolingPoolingConventionValues[] = {
4824  "full",
4825  "same",
4826  "valid"
4827  };
4828  static const char *PoolingLayoutValues[] = {
4829  "None",
4830  "NCDHW",
4831  "NCHW",
4832  "NCW",
4833  "NDHWC",
4834  "NHWC",
4835  "NWC"
4836  };
4837  return Operator("Pooling")
4838  .SetParam("kernel", kernel)
4839  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
4840  .SetParam("global_pool", global_pool)
4841  .SetParam("cudnn_off", cudnn_off)
4842  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
4843  .SetParam("stride", stride)
4844  .SetParam("pad", pad)
4845  .SetParam("p_value", p_value)
4846  .SetParam("count_include_pad", count_include_pad)
4847  .SetParam("layout", PoolingLayoutValues[int(layout)])
4848  .SetInput("data", data)
4849  .CreateSymbol(symbol_name);
4850 }
4851 
/*!
 * \brief Choices for the "dtype" parameter of softmax().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by softmax(); the matching string is shown next to it.
 */
enum class SoftmaxDtype {
  kNone = 0,  // "None"
  kFloat16,   // "float16"
  kFloat32,   // "float32"
  kFloat64    // "float64"
};
4860 
4894 inline Symbol softmax(const std::string& symbol_name,
4895  Symbol data,
4896  int axis = -1,
4897  dmlc::optional<double> temperature = dmlc::optional<double>(),
4899  static const char *SoftmaxDtypeValues[] = {
4900  "None",
4901  "float16",
4902  "float32",
4903  "float64"
4904  };
4905  return Operator("softmax")
4906  .SetParam("axis", axis)
4907  .SetParam("temperature", temperature)
4908  .SetParam("dtype", SoftmaxDtypeValues[int(dtype)])
4909  .SetInput("data", data)
4910  .CreateSymbol(symbol_name);
4911 }
4912 
/*!
 * \brief Choices for the "dtype" parameter of softmin().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by softmin(); the matching string is shown next to it.
 */
enum class SoftminDtype {
  kNone = 0,  // "None"
  kFloat16,   // "float16"
  kFloat32,   // "float32"
  kFloat64    // "float64"
};
4921 
4956 inline Symbol softmin(const std::string& symbol_name,
4957  Symbol data,
4958  int axis = -1,
4959  dmlc::optional<double> temperature = dmlc::optional<double>(),
4961  static const char *SoftminDtypeValues[] = {
4962  "None",
4963  "float16",
4964  "float32",
4965  "float64"
4966  };
4967  return Operator("softmin")
4968  .SetParam("axis", axis)
4969  .SetParam("temperature", temperature)
4970  .SetParam("dtype", SoftminDtypeValues[int(dtype)])
4971  .SetInput("data", data)
4972  .CreateSymbol(symbol_name);
4973 }
4974 
/*!
 * \brief Choices for the "dtype" parameter of log_softmax().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by log_softmax(); the matching string is shown next to it.
 */
enum class Log_softmaxDtype {
  kNone = 0,  // "None"
  kFloat16,   // "float16"
  kFloat32,   // "float32"
  kFloat64    // "float64"
};
4983 
5008 inline Symbol log_softmax(const std::string& symbol_name,
5009  Symbol data,
5010  int axis = -1,
5011  dmlc::optional<double> temperature = dmlc::optional<double>(),
5013  static const char *Log_softmaxDtypeValues[] = {
5014  "None",
5015  "float16",
5016  "float32",
5017  "float64"
5018  };
5019  return Operator("log_softmax")
5020  .SetParam("axis", axis)
5021  .SetParam("temperature", temperature)
5022  .SetParam("dtype", Log_softmaxDtypeValues[int(dtype)])
5023  .SetInput("data", data)
5024  .CreateSymbol(symbol_name);
5025 }
5026 
5030  kNone = 0,
5031  kFastest = 1,
5032  kLimited_workspace = 2,
5033  kOff = 3
5034 };
5035 
5039  kNone = 0,
5040  kNCDHW = 1,
5041  kNCHW = 2,
5042  kNCW = 3,
5043  kNDHWC = 4,
5044  kNHWC = 5
5045 };
5046 
5076 inline Symbol Deconvolution(const std::string& symbol_name,
5077  Symbol data,
5078  Symbol weight,
5079  Symbol bias,
5080  Shape kernel,
5081  uint32_t num_filter,
5082  Shape stride = Shape(),
5083  Shape dilate = Shape(),
5084  Shape pad = Shape(),
5085  Shape adj = Shape(),
5086  Shape target_shape = Shape(),
5087  uint32_t num_group = 1,
5088  uint64_t workspace = 512,
5089  bool no_bias = true,
5091  bool cudnn_off = false,
5093  static const char *DeconvolutionCudnnTuneValues[] = {
5094  "None",
5095  "fastest",
5096  "limited_workspace",
5097  "off"
5098  };
5099  static const char *DeconvolutionLayoutValues[] = {
5100  "None",
5101  "NCDHW",
5102  "NCHW",
5103  "NCW",
5104  "NDHWC",
5105  "NHWC"
5106  };
5107  return Operator("Deconvolution")
5108  .SetParam("kernel", kernel)
5109  .SetParam("num_filter", num_filter)
5110  .SetParam("stride", stride)
5111  .SetParam("dilate", dilate)
5112  .SetParam("pad", pad)
5113  .SetParam("adj", adj)
5114  .SetParam("target_shape", target_shape)
5115  .SetParam("num_group", num_group)
5116  .SetParam("workspace", workspace)
5117  .SetParam("no_bias", no_bias)
5118  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
5119  .SetParam("cudnn_off", cudnn_off)
5120  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
5121  .SetInput("data", data)
5122  .SetInput("weight", weight)
5123  .SetInput("bias", bias)
5124  .CreateSymbol(symbol_name);
5125 }
5126 
/*!
 * \brief Choices for the "act_type" parameter of Activation().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by Activation(); the matching string is shown next to it.
 */
enum class ActivationActType {
  kRelu = 0,  // "relu"
  kSigmoid,   // "sigmoid"
  kSoftrelu,  // "softrelu"
  kSoftsign,  // "softsign"
  kTanh       // "tanh"
};
5136 
5156 inline Symbol Activation(const std::string& symbol_name,
5157  Symbol data,
5158  ActivationActType act_type) {
5159  static const char *ActivationActTypeValues[] = {
5160  "relu",
5161  "sigmoid",
5162  "softrelu",
5163  "softsign",
5164  "tanh"
5165  };
5166  return Operator("Activation")
5167  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
5168  .SetInput("data", data)
5169  .CreateSymbol(symbol_name);
5170 }
5171 
5240 inline Symbol BatchNorm(const std::string& symbol_name,
5241  Symbol data,
5242  Symbol gamma,
5243  Symbol beta,
5244  Symbol moving_mean,
5245  Symbol moving_var,
5246  double eps = 0.0010000000474974513,
5247  mx_float momentum = 0.899999976,
5248  bool fix_gamma = true,
5249  bool use_global_stats = false,
5250  bool output_mean_var = false,
5251  int axis = 1,
5252  bool cudnn_off = false) {
5253  return Operator("BatchNorm")
5254  .SetParam("eps", eps)
5255  .SetParam("momentum", momentum)
5256  .SetParam("fix_gamma", fix_gamma)
5257  .SetParam("use_global_stats", use_global_stats)
5258  .SetParam("output_mean_var", output_mean_var)
5259  .SetParam("axis", axis)
5260  .SetParam("cudnn_off", cudnn_off)
5261  .SetInput("data", data)
5262  .SetInput("gamma", gamma)
5263  .SetInput("beta", beta)
5264  .SetInput("moving_mean", moving_mean)
5265  .SetInput("moving_var", moving_var)
5266  .CreateSymbol(symbol_name);
5267 }
5268 
/*!
 * \brief Choices for the "blank_label" parameter of CTCLoss().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by CTCLoss(); the matching string is shown next to it.
 */
enum class CTCLossBlankLabel {
  kFirst = 0,  // "first"
  kLast        // "last"
};
5279 
5346 inline Symbol CTCLoss(const std::string& symbol_name,
5347  Symbol data,
5348  Symbol label,
5349  Symbol data_lengths,
5350  Symbol label_lengths,
5351  bool use_data_lengths = false,
5352  bool use_label_lengths = false,
5354  static const char *CTCLossBlankLabelValues[] = {
5355  "first",
5356  "last"
5357  };
5358  return Operator("CTCLoss")
5359  .SetParam("use_data_lengths", use_data_lengths)
5360  .SetParam("use_label_lengths", use_label_lengths)
5361  .SetParam("blank_label", CTCLossBlankLabelValues[int(blank_label)])
5362  .SetInput("data", data)
5363  .SetInput("label", label)
5364  .SetInput("data_lengths", data_lengths)
5365  .SetInput("label_lengths", label_lengths)
5366  .CreateSymbol(symbol_name);
5367 }
5368 
5412 inline Symbol FullyConnected(const std::string& symbol_name,
5413  Symbol data,
5414  Symbol weight,
5415  Symbol bias,
5416  int num_hidden,
5417  bool no_bias = false,
5418  bool flatten = true) {
5419  return Operator("FullyConnected")
5420  .SetParam("num_hidden", num_hidden)
5421  .SetParam("no_bias", no_bias)
5422  .SetParam("flatten", flatten)
5423  .SetInput("data", data)
5424  .SetInput("weight", weight)
5425  .SetInput("bias", bias)
5426  .CreateSymbol(symbol_name);
5427 }
5428 
5432  kNone = 0,
5433  kFastest = 1,
5434  kLimited_workspace = 2,
5435  kOff = 3
5436 };
5437 
5441 enum class ConvolutionLayout {
5442  kNone = 0,
5443  kNCDHW = 1,
5444  kNCHW = 2,
5445  kNCW = 3,
5446  kNDHWC = 4,
5447  kNHWC = 5
5448 };
5449 
5547 inline Symbol Convolution(const std::string& symbol_name,
5548  Symbol data,
5549  Symbol weight,
5550  Symbol bias,
5551  Shape kernel,
5552  uint32_t num_filter,
5553  Shape stride = Shape(),
5554  Shape dilate = Shape(),
5555  Shape pad = Shape(),
5556  uint32_t num_group = 1,
5557  uint64_t workspace = 1024,
5558  bool no_bias = false,
5560  bool cudnn_off = false,
5562  static const char *ConvolutionCudnnTuneValues[] = {
5563  "None",
5564  "fastest",
5565  "limited_workspace",
5566  "off"
5567  };
5568  static const char *ConvolutionLayoutValues[] = {
5569  "None",
5570  "NCDHW",
5571  "NCHW",
5572  "NCW",
5573  "NDHWC",
5574  "NHWC"
5575  };
5576  return Operator("Convolution")
5577  .SetParam("kernel", kernel)
5578  .SetParam("num_filter", num_filter)
5579  .SetParam("stride", stride)
5580  .SetParam("dilate", dilate)
5581  .SetParam("pad", pad)
5582  .SetParam("num_group", num_group)
5583  .SetParam("workspace", workspace)
5584  .SetParam("no_bias", no_bias)
5585  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
5586  .SetParam("cudnn_off", cudnn_off)
5587  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
5588  .SetInput("data", data)
5589  .SetInput("weight", weight)
5590  .SetInput("bias", bias)
5591  .CreateSymbol(symbol_name);
5592 }
5593 
5597  kBilinear = 0,
5598  kNearest = 1
5599 };
5600 
5605  kConcat = 0,
5606  kSum = 1
5607 };
5608 
5675 inline Symbol UpSampling(const std::string& symbol_name,
5676  const std::vector<Symbol>& data,
5677  int scale,
5678  UpSamplingSampleType sample_type,
5679  int num_args,
5680  int num_filter = 0,
5682  uint64_t workspace = 512) {
5683  static const char *UpSamplingSampleTypeValues[] = {
5684  "bilinear",
5685  "nearest"
5686  };
5687  static const char *UpSamplingMultiInputModeValues[] = {
5688  "concat",
5689  "sum"
5690  };
5691  return Operator("UpSampling")
5692  .SetParam("scale", scale)
5693  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
5694  .SetParam("num_args", num_args)
5695  .SetParam("num_filter", num_filter)
5696  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
5697  .SetParam("workspace", workspace)
5698 (data)
5699  .CreateSymbol(symbol_name);
5700 }
5701 
5748 inline Symbol Concat(const std::string& symbol_name,
5749  const std::vector<Symbol>& data,
5750  int num_args,
5751  int dim = 1) {
5752  return Operator("Concat")
5753  .SetParam("num_args", num_args)
5754  .SetParam("dim", dim)
5755 (data)
5756  .CreateSymbol(symbol_name);
5757 }
5758 
5797 inline Symbol LayerNorm(const std::string& symbol_name,
5798  Symbol data,
5799  Symbol gamma,
5800  Symbol beta,
5801  int axis = -1,
5802  mx_float eps = 9.99999975e-06,
5803  bool output_mean_var = false) {
5804  return Operator("LayerNorm")
5805  .SetParam("axis", axis)
5806  .SetParam("eps", eps)
5807  .SetParam("output_mean_var", output_mean_var)
5808  .SetInput("data", data)
5809  .SetInput("gamma", gamma)
5810  .SetInput("beta", beta)
5811  .CreateSymbol(symbol_name);
5812 }
5813 
5841 inline Symbol LRN(const std::string& symbol_name,
5842  Symbol data,
5843  uint32_t nsize,
5844  mx_float alpha = 9.99999975e-05,
5845  mx_float beta = 0.75,
5846  mx_float knorm = 2) {
5847  return Operator("LRN")
5848  .SetParam("nsize", nsize)
5849  .SetParam("alpha", alpha)
5850  .SetParam("beta", beta)
5851  .SetParam("knorm", knorm)
5852  .SetInput("data", data)
5853  .CreateSymbol(symbol_name);
5854 }
5855 
/*!
 * \brief Choices for the "mode" parameter of Dropout().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by Dropout(); the matching string is shown next to it.
 */
enum class DropoutMode {
  kAlways = 0,  // "always"
  kTraining     // "training"
};
5862 
5904 inline Symbol Dropout(const std::string& symbol_name,
5905  Symbol data,
5906  mx_float p = 0.5,
5908  Shape axes = Shape(),
5909  dmlc::optional<bool> cudnn_off = dmlc::optional<bool>(0)) {
5910  static const char *DropoutModeValues[] = {
5911  "always",
5912  "training"
5913  };
5914  return Operator("Dropout")
5915  .SetParam("p", p)
5916  .SetParam("mode", DropoutModeValues[int(mode)])
5917  .SetParam("axes", axes)
5918  .SetParam("cudnn_off", cudnn_off)
5919  .SetInput("data", data)
5920  .CreateSymbol(symbol_name);
5921 }
5922 
5927  kChannel = 0,
5928  kInstance = 1
5929 };
5930 
5964 inline Symbol SoftmaxActivation(const std::string& symbol_name,
5965  Symbol data,
5967  static const char *SoftmaxActivationModeValues[] = {
5968  "channel",
5969  "instance"
5970  };
5971  return Operator("SoftmaxActivation")
5972  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
5973  .SetInput("data", data)
5974  .CreateSymbol(symbol_name);
5975 }
5976 
6006 inline Symbol moments(const std::string& symbol_name,
6007  Symbol data,
6008  dmlc::optional<Shape> axes = dmlc::optional<Shape>(),
6009  bool keepdims = false) {
6010  return Operator("moments")
6011  .SetParam("axes", axes)
6012  .SetParam("keepdims", keepdims)
6013  .SetInput("data", data)
6014  .CreateSymbol(symbol_name);
6015 }
6016 
/*!
 * \brief Choices for the "act_type" parameter of LeakyReLU().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by LeakyReLU(); the matching string is shown next to it.
 */
enum class LeakyReLUActType {
  kElu = 0,  // "elu"
  kGelu,     // "gelu"
  kLeaky,    // "leaky"
  kPrelu,    // "prelu"
  kRrelu,    // "rrelu"
  kSelu      // "selu"
};
6027 
6058 inline Symbol LeakyReLU(const std::string& symbol_name,
6059  Symbol data,
6060  Symbol gamma,
6062  mx_float slope = 0.25,
6063  mx_float lower_bound = 0.125,
6064  mx_float upper_bound = 0.333999991) {
6065  static const char *LeakyReLUActTypeValues[] = {
6066  "elu",
6067  "gelu",
6068  "leaky",
6069  "prelu",
6070  "rrelu",
6071  "selu"
6072  };
6073  return Operator("LeakyReLU")
6074  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
6075  .SetParam("slope", slope)
6076  .SetParam("lower_bound", lower_bound)
6077  .SetParam("upper_bound", upper_bound)
6078  .SetInput("data", data)
6079  .SetInput("gamma", gamma)
6080  .CreateSymbol(symbol_name);
6081 }
6082 
/*!
 * \brief Choices for the "mode" parameter of RNN().
 *
 * Each enumerator's numeric value is its index into the string table used
 * by RNN(); the matching string is shown next to it.
 */
enum class RNNMode {
  kGru = 0,   // "gru"
  kLstm,      // "lstm"
  kRnn_relu,  // "rnn_relu"
  kRnn_tanh   // "rnn_tanh"
};
6091 
6168 inline Symbol RNN(const std::string& symbol_name,
6169  Symbol data,
6170  Symbol parameters,
6171  Symbol state,
6172  Symbol state_cell,
6173  Symbol sequence_length,
6174  uint32_t state_size,
6175  uint32_t num_layers,
6176  RNNMode mode,
6177  bool bidirectional = false,
6178  mx_float p = 0,
6179  bool state_outputs = false,
6180  dmlc::optional<int> projection_size = dmlc::optional<int>(),
6181  dmlc::optional<double> lstm_state_clip_min = dmlc::optional<double>(),
6182  dmlc::optional<double> lstm_state_clip_max = dmlc::optional<double>(),
6183  bool lstm_state_clip_nan = false,
6184  bool use_sequence_length = false) {
6185  static const char *RNNModeValues[] = {
6186  "gru",
6187  "lstm",
6188  "rnn_relu",
6189  "rnn_tanh"
6190  };
6191  return Operator("RNN")
6192  .SetParam("state_size", state_size)
6193  .SetParam("num_layers", num_layers)
6194  .SetParam("mode", RNNModeValues[int(mode)])
6195  .SetParam("bidirectional", bidirectional)
6196  .SetParam("p", p)
6197  .SetParam("state_outputs", state_outputs)
6198  .SetParam("projection_size", projection_size)
6199  .SetParam("lstm_state_clip_min", lstm_state_clip_min)
6200  .SetParam("lstm_state_clip_max", lstm_state_clip_max)
6201  .SetParam("lstm_state_clip_nan", lstm_state_clip_nan)
6202  .SetParam("use_sequence_length", use_sequence_length)
6203  .SetInput("data", data)
6204  .SetInput("parameters", parameters)
6205  .SetInput("state", state)
6206  .SetInput("state_cell", state_cell)
6207  .SetInput("sequence_length", sequence_length)
6208  .CreateSymbol(symbol_name);
6209 }
6210 
6214  kBatch = 0,
6215  kNull = 1,
6216  kValid = 2
6217 };
6218 
6314 inline Symbol SoftmaxOutput(const std::string& symbol_name,
6315  Symbol data,
6316  Symbol label,
6317  mx_float grad_scale = 1,
6318  mx_float ignore_label = -1,
6319  bool multi_output = false,
6320  bool use_ignore = false,
6321  bool preserve_shape = false,
6323  bool out_grad = false,
6324  mx_float smooth_alpha = 0) {
6325  static const char *SoftmaxOutputNormalizationValues[] = {
6326  "batch",
6327  "null",
6328  "valid"
6329  };
6330  return Operator("SoftmaxOutput")
6331  .SetParam("grad_scale", grad_scale)
6332  .SetParam("ignore_label", ignore_label)
6333  .SetParam("multi_output", multi_output)
6334  .SetParam("use_ignore", use_ignore)
6335  .SetParam("preserve_shape", preserve_shape)
6336  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
6337  .SetParam("out_grad", out_grad)
6338  .SetParam("smooth_alpha", smooth_alpha)
6339  .SetInput("data", data)
6340  .SetInput("label", label)
6341  .CreateSymbol(symbol_name);
6342 }
6343 
6372 inline Symbol SwapAxis(const std::string& symbol_name,
6373  Symbol data,
6374  uint32_t dim1 = 0,
6375  uint32_t dim2 = 0) {
6376  return Operator("SwapAxis")
6377  .SetParam("dim1", dim1)
6378  .SetParam("dim2", dim2)
6379  .SetInput("data", data)
6380  .CreateSymbol(symbol_name);
6381 }
6382 
6443 inline Symbol BatchNorm_v1(const std::string& symbol_name,
6444  Symbol data,
6445  Symbol gamma,
6446  Symbol beta,
6447  mx_float eps = 0.00100000005,
6448  mx_float momentum = 0.899999976,
6449  bool fix_gamma = true,
6450  bool use_global_stats = false,
6451  bool output_mean_var = false) {
6452  return Operator("BatchNorm_v1")
6453  .SetParam("eps", eps)
6454  .SetParam("momentum", momentum)
6455  .SetParam("fix_gamma", fix_gamma)
6456  .SetParam("use_global_stats", use_global_stats)
6457  .SetParam("output_mean_var", output_mean_var)
6458  .SetInput("data", data)
6459  .SetInput("gamma", gamma)
6460  .SetInput("beta", beta)
6461  .CreateSymbol(symbol_name);
6462 }
6463 
6501 inline Symbol softmax_cross_entropy(const std::string& symbol_name,
6502  Symbol data,
6503  Symbol label) {
6504  return Operator("softmax_cross_entropy")
6505  .SetInput("data", data)
6506  .SetInput("label", label)
6507  .CreateSymbol(symbol_name);
6508 }
6509 
6539 inline Symbol LinearRegressionOutput(const std::string& symbol_name,
6540  Symbol data,
6541  Symbol label,
6542  mx_float grad_scale = 1) {
6543  return Operator("LinearRegressionOutput")
6544  .SetParam("grad_scale", grad_scale)
6545  .SetInput("data", data)
6546  .SetInput("label", label)
6547  .CreateSymbol(symbol_name);
6548 }
6549 
6580 inline Symbol MAERegressionOutput(const std::string& symbol_name,
6581  Symbol data,
6582  Symbol label,
6583  mx_float grad_scale = 1) {
6584  return Operator("MAERegressionOutput")
6585  .SetParam("grad_scale", grad_scale)
6586  .SetInput("data", data)
6587  .SetInput("label", label)
6588  .CreateSymbol(symbol_name);
6589 }
6590 
6627 inline Symbol LogisticRegressionOutput(const std::string& symbol_name,
6628  Symbol data,
6629  Symbol label,
6630  mx_float grad_scale = 1) {
6631  return Operator("LogisticRegressionOutput")
6632  .SetParam("grad_scale", grad_scale)
6633  .SetInput("data", data)
6634  .SetInput("label", label)
6635  .CreateSymbol(symbol_name);
6636 }
6637 
6647 inline Symbol IdentityAttachKLSparseReg(const std::string& symbol_name,
6648  Symbol data,
6649  mx_float sparseness_target = 0.100000001,
6650  mx_float penalty = 0.00100000005,
6651  mx_float momentum = 0.899999976) {
6652  return Operator("IdentityAttachKLSparseReg")
6653  .SetParam("sparseness_target", sparseness_target)
6654  .SetParam("penalty", penalty)
6655  .SetParam("momentum", momentum)
6656  .SetInput("data", data)
6657  .CreateSymbol(symbol_name);
6658 }
6659 
6688 inline Symbol signsgd_update(const std::string& symbol_name,
6689  Symbol weight,
6690  Symbol grad,
6691  mx_float lr,
6692  mx_float wd = 0,
6693  mx_float rescale_grad = 1,
6694  mx_float clip_gradient = -1) {
6695  return Operator("signsgd_update")
6696  .SetParam("lr", lr)
6697  .SetParam("wd", wd)
6698  .SetParam("rescale_grad", rescale_grad)
6699  .SetParam("clip_gradient", clip_gradient)
6700  .SetInput("weight", weight)
6701  .SetInput("grad", grad)
6702  .CreateSymbol(symbol_name);
6703 }
6704 
6739 inline Symbol signum_update(const std::string& symbol_name,
6740  Symbol weight,
6741  Symbol grad,
6742  Symbol mom,
6743  mx_float lr,
6744  mx_float momentum = 0,
6745  mx_float wd = 0,
6746  mx_float rescale_grad = 1,
6747  mx_float clip_gradient = -1,
6748  mx_float wd_lh = 0) {
6749  return Operator("signum_update")
6750  .SetParam("lr", lr)
6751  .SetParam("momentum", momentum)
6752  .SetParam("wd", wd)
6753  .SetParam("rescale_grad", rescale_grad)
6754  .SetParam("clip_gradient", clip_gradient)
6755  .SetParam("wd_lh", wd_lh)
6756  .SetInput("weight", weight)
6757  .SetInput("grad", grad)
6758  .SetInput("mom", mom)
6759  .CreateSymbol(symbol_name);
6760 }
6761 
6783 inline Symbol multi_sgd_update(const std::string& symbol_name,
6784  const std::vector<Symbol>& data,
6785  nnvm::Tuple<mx_float> lrs,
6786  nnvm::Tuple<mx_float> wds,
6787  mx_float rescale_grad = 1,
6788  mx_float clip_gradient = -1,
6789  int num_weights = 1) {
6790  return Operator("multi_sgd_update")
6791  .SetParam("lrs", lrs)
6792  .SetParam("wds", wds)
6793  .SetParam("rescale_grad", rescale_grad)
6794  .SetParam("clip_gradient", clip_gradient)
6795  .SetParam("num_weights", num_weights)
6796 (data)
6797  .CreateSymbol(symbol_name);
6798 }
6799 
6834 inline Symbol multi_sgd_mom_update(const std::string& symbol_name,
6835  const std::vector<Symbol>& data,
6836  nnvm::Tuple<mx_float> lrs,
6837  nnvm::Tuple<mx_float> wds,
6838  mx_float momentum = 0,
6839  mx_float rescale_grad = 1,
6840  mx_float clip_gradient = -1,
6841  int num_weights = 1) {
6842  return Operator("multi_sgd_mom_update")
6843  .SetParam("lrs", lrs)
6844  .SetParam("wds", wds)
6845  .SetParam("momentum", momentum)
6846  .SetParam("rescale_grad", rescale_grad)
6847  .SetParam("clip_gradient", clip_gradient)
6848  .SetParam("num_weights", num_weights)
6849 (data)
6850  .CreateSymbol(symbol_name);
6851 }
6852 
6874 inline Symbol multi_mp_sgd_update(const std::string& symbol_name,
6875  const std::vector<Symbol>& data,
6876  nnvm::Tuple<mx_float> lrs,
6877  nnvm::Tuple<mx_float> wds,
6878  mx_float rescale_grad = 1,
6879  mx_float clip_gradient = -1,
6880  int num_weights = 1) {
6881  return Operator("multi_mp_sgd_update")
6882  .SetParam("lrs", lrs)
6883  .SetParam("wds", wds)
6884  .SetParam("rescale_grad", rescale_grad)
6885  .SetParam("clip_gradient", clip_gradient)
6886  .SetParam("num_weights", num_weights)
6887 (data)
6888  .CreateSymbol(symbol_name);
6889 }
6890 
6925 inline Symbol multi_mp_sgd_mom_update(const std::string& symbol_name,
6926  const std::vector<Symbol>& data,
6927  nnvm::Tuple<mx_float> lrs,
6928  nnvm::Tuple<mx_float> wds,
6929  mx_float momentum = 0,
6930  mx_float rescale_grad = 1,
6931  mx_float clip_gradient = -1,
6932  int num_weights = 1) {
6933  return Operator("multi_mp_sgd_mom_update")
6934  .SetParam("lrs", lrs)
6935  .SetParam("wds", wds)
6936  .SetParam("momentum", momentum)
6937  .SetParam("rescale_grad", rescale_grad)
6938  .SetParam("clip_gradient", clip_gradient)
6939  .SetParam("num_weights", num_weights)
6940 (data)
6941  .CreateSymbol(symbol_name);
6942 }
6943 
6972 inline Symbol sgd_update(const std::string& symbol_name,
6973  Symbol weight,
6974  Symbol grad,
6975  mx_float lr,
6976  mx_float wd = 0,
6977  mx_float rescale_grad = 1,
6978  mx_float clip_gradient = -1,
6979  bool lazy_update = true) {
6980  return Operator("sgd_update")
6981  .SetParam("lr", lr)
6982  .SetParam("wd", wd)
6983  .SetParam("rescale_grad", rescale_grad)
6984  .SetParam("clip_gradient", clip_gradient)
6985  .SetParam("lazy_update", lazy_update)
6986  .SetInput("weight", weight)
6987  .SetInput("grad", grad)
6988  .CreateSymbol(symbol_name);
6989 }
6990 
7035 inline Symbol sgd_mom_update(const std::string& symbol_name,
7036  Symbol weight,
7037  Symbol grad,
7038  Symbol mom,
7039  mx_float lr,
7040  mx_float momentum = 0,
7041  mx_float wd = 0,
7042  mx_float rescale_grad = 1,
7043  mx_float clip_gradient = -1,
7044  bool lazy_update = true) {
7045  return Operator("sgd_mom_update")
7046  .SetParam("lr", lr)
7047  .SetParam("momentum", momentum)
7048  .SetParam("wd", wd)
7049  .SetParam("rescale_grad", rescale_grad)
7050  .SetParam("clip_gradient", clip_gradient)
7051  .SetParam("lazy_update", lazy_update)
7052  .SetInput("weight", weight)
7053  .SetInput("grad", grad)
7054  .SetInput("mom", mom)
7055  .CreateSymbol(symbol_name);
7056 }
7057 
7073 inline Symbol mp_sgd_update(const std::string& symbol_name,
7074  Symbol weight,
7075  Symbol grad,
7076  Symbol weight32,
7077  mx_float lr,
7078  mx_float wd = 0,
7079  mx_float rescale_grad = 1,
7080  mx_float clip_gradient = -1,
7081  bool lazy_update = true) {
7082  return Operator("mp_sgd_update")
7083  .SetParam("lr", lr)
7084  .SetParam("wd", wd)
7085  .SetParam("rescale_grad", rescale_grad)
7086  .SetParam("clip_gradient", clip_gradient)
7087  .SetParam("lazy_update", lazy_update)
7088  .SetInput("weight", weight)
7089  .SetInput("grad", grad)
7090  .SetInput("weight32", weight32)
7091  .CreateSymbol(symbol_name);
7092 }
7093 
7111 inline Symbol mp_sgd_mom_update(const std::string& symbol_name,
7112  Symbol weight,
7113  Symbol grad,
7114  Symbol mom,
7115  Symbol weight32,
7116  mx_float lr,
7117  mx_float momentum = 0,
7118  mx_float wd = 0,
7119  mx_float rescale_grad = 1,
7120  mx_float clip_gradient = -1,
7121  bool lazy_update = true) {
7122  return Operator("mp_sgd_mom_update")
7123  .SetParam("lr", lr)
7124  .SetParam("momentum", momentum)
7125  .SetParam("wd", wd)
7126  .SetParam("rescale_grad", rescale_grad)
7127  .SetParam("clip_gradient", clip_gradient)
7128  .SetParam("lazy_update", lazy_update)
7129  .SetInput("weight", weight)
7130  .SetInput("grad", grad)
7131  .SetInput("mom", mom)
7132  .SetInput("weight32", weight32)
7133  .CreateSymbol(symbol_name);
7134 }
7135 
7171 inline Symbol ftml_update(const std::string& symbol_name,
7172  Symbol weight,
7173  Symbol grad,
7174  Symbol d,
7175  Symbol v,
7176  Symbol z,
7177  mx_float lr,
7178  int t,
7179  mx_float beta1 = 0.600000024,
7180  mx_float beta2 = 0.999000013,
7181  double epsilon = 9.9999999392252903e-09,
7182  mx_float wd = 0,
7183  mx_float rescale_grad = 1,
7184  mx_float clip_grad = -1) {
7185  return Operator("ftml_update")
7186  .SetParam("lr", lr)
7187  .SetParam("t", t)
7188  .SetParam("beta1", beta1)
7189  .SetParam("beta2", beta2)
7190  .SetParam("epsilon", epsilon)
7191  .SetParam("wd", wd)
7192  .SetParam("rescale_grad", rescale_grad)
7193  .SetParam("clip_grad", clip_grad)
7194  .SetInput("weight", weight)
7195  .SetInput("grad", grad)
7196  .SetInput("d", d)
7197  .SetInput("v", v)
7198  .SetInput("z", z)
7199  .CreateSymbol(symbol_name);
7200 }
7201 
7251 inline Symbol adam_update(const std::string& symbol_name,
7252  Symbol weight,
7253  Symbol grad,
7254  Symbol mean,
7255  Symbol var,
7256  mx_float lr,
7257  mx_float beta1 = 0.899999976,
7258  mx_float beta2 = 0.999000013,
7259  mx_float epsilon = 9.99999994e-09,
7260  mx_float wd = 0,
7261  mx_float rescale_grad = 1,
7262  mx_float clip_gradient = -1,
7263  bool lazy_update = true) {
7264  return Operator("adam_update")
7265  .SetParam("lr", lr)
7266  .SetParam("beta1", beta1)
7267  .SetParam("beta2", beta2)
7268  .SetParam("epsilon", epsilon)
7269  .SetParam("wd", wd)
7270  .SetParam("rescale_grad", rescale_grad)
7271  .SetParam("clip_gradient", clip_gradient)
7272  .SetParam("lazy_update", lazy_update)
7273  .SetInput("weight", weight)
7274  .SetInput("grad", grad)
7275  .SetInput("mean", mean)
7276  .SetInput("var", var)
7277  .CreateSymbol(symbol_name);
7278 }
7279 
7310 inline Symbol nag_mom_update(const std::string& symbol_name,
7311  Symbol weight,
7312  Symbol grad,
7313  Symbol mom,
7314  mx_float lr,
7315  mx_float momentum = 0,
7316  mx_float wd = 0,
7317  mx_float rescale_grad = 1,
7318  mx_float clip_gradient = -1) {
7319  return Operator("nag_mom_update")
7320  .SetParam("lr", lr)
7321  .SetParam("momentum", momentum)
7322  .SetParam("wd", wd)
7323  .SetParam("rescale_grad", rescale_grad)
7324  .SetParam("clip_gradient", clip_gradient)
7325  .SetInput("weight", weight)
7326  .SetInput("grad", grad)
7327  .SetInput("mom", mom)
7328  .CreateSymbol(symbol_name);
7329 }
7330 
7350 inline Symbol mp_nag_mom_update(const std::string& symbol_name,
7351  Symbol weight,
7352  Symbol grad,
7353  Symbol mom,
7354  Symbol weight32,
7355  mx_float lr,
7356  mx_float momentum = 0,
7357  mx_float wd = 0,
7358  mx_float rescale_grad = 1,
7359  mx_float clip_gradient = -1) {
7360  return Operator("mp_nag_mom_update")
7361  .SetParam("lr", lr)
7362  .SetParam("momentum", momentum)
7363  .SetParam("wd", wd)
7364  .SetParam("rescale_grad", rescale_grad)
7365  .SetParam("clip_gradient", clip_gradient)
7366  .SetInput("weight", weight)
7367  .SetInput("grad", grad)
7368  .SetInput("mom", mom)
7369  .SetInput("weight32", weight32)
7370  .CreateSymbol(symbol_name);
7371 }
7372 
7426 inline Symbol rmsprop_update(const std::string& symbol_name,
7427  Symbol weight,
7428  Symbol grad,
7429  Symbol n,
7430  mx_float lr,
7431  mx_float gamma1 = 0.949999988,
7432  mx_float epsilon = 9.99999994e-09,
7433  mx_float wd = 0,
7434  mx_float rescale_grad = 1,
7435  mx_float clip_gradient = -1,
7436  mx_float clip_weights = -1) {
7437  return Operator("rmsprop_update")
7438  .SetParam("lr", lr)
7439  .SetParam("gamma1", gamma1)
7440  .SetParam("epsilon", epsilon)
7441  .SetParam("wd", wd)
7442  .SetParam("rescale_grad", rescale_grad)
7443  .SetParam("clip_gradient", clip_gradient)
7444  .SetParam("clip_weights", clip_weights)
7445  .SetInput("weight", weight)
7446  .SetInput("grad", grad)
7447  .SetInput("n", n)
7448  .CreateSymbol(symbol_name);
7449 }
7450 
7496 inline Symbol rmspropalex_update(const std::string& symbol_name,
7497  Symbol weight,
7498  Symbol grad,
7499  Symbol n,
7500  Symbol g,
7501  Symbol delta,
7502  mx_float lr,
7503  mx_float gamma1 = 0.949999988,
7504  mx_float gamma2 = 0.899999976,
7505  mx_float epsilon = 9.99999994e-09,
7506  mx_float wd = 0,
7507  mx_float rescale_grad = 1,
7508  mx_float clip_gradient = -1,
7509  mx_float clip_weights = -1) {
7510  return Operator("rmspropalex_update")
7511  .SetParam("lr", lr)
7512  .SetParam("gamma1", gamma1)
7513  .SetParam("gamma2", gamma2)
7514  .SetParam("epsilon", epsilon)
7515  .SetParam("wd", wd)
7516  .SetParam("rescale_grad", rescale_grad)
7517  .SetParam("clip_gradient", clip_gradient)
7518  .SetParam("clip_weights", clip_weights)
7519  .SetInput("weight", weight)
7520  .SetInput("grad", grad)
7521  .SetInput("n", n)
7522  .SetInput("g", g)
7523  .SetInput("delta", delta)
7524  .CreateSymbol(symbol_name);
7525 }
7526 
7566 inline Symbol ftrl_update(const std::string& symbol_name,
7567  Symbol weight,
7568  Symbol grad,
7569  Symbol z,
7570  Symbol n,
7571  mx_float lr,
7572  mx_float lamda1 = 0.00999999978,
7573  mx_float beta = 1,
7574  mx_float wd = 0,
7575  mx_float rescale_grad = 1,
7576  mx_float clip_gradient = -1) {
7577  return Operator("ftrl_update")
7578  .SetParam("lr", lr)
7579  .SetParam("lamda1", lamda1)
7580  .SetParam("beta", beta)
7581  .SetParam("wd", wd)
7582  .SetParam("rescale_grad", rescale_grad)
7583  .SetParam("clip_gradient", clip_gradient)
7584  .SetInput("weight", weight)
7585  .SetInput("grad", grad)
7586  .SetInput("z", z)
7587  .SetInput("n", n)
7588  .CreateSymbol(symbol_name);
7589 }
7590 
7662 inline Symbol SliceChannel(const std::string& symbol_name,
7663  Symbol data,
7664  int num_outputs,
7665  int axis = 1,
7666  bool squeeze_axis = false) {
7667  return Operator("SliceChannel")
7668  .SetParam("num_outputs", num_outputs)
7669  .SetParam("axis", axis)
7670  .SetParam("squeeze_axis", squeeze_axis)
7671  .SetInput("data", data)
7672  .CreateSymbol(symbol_name);
7673 }
7674 
/*!
 * \brief Mode selector for Pad(). The numeric value of each enumerator is
 *        used as an index into a string table, so the order must not change.
 */
enum class PadMode {
  kConstant,  // value 0
  kEdge,      // value 1
  kReflect    // value 2
};
7683 
7780 inline Symbol Pad(const std::string& symbol_name,
7781  Symbol data,
7782  PadMode mode,
7783  Shape pad_width,
7784  double constant_value = 0) {
7785  static const char *PadModeValues[] = {
7786  "constant",
7787  "edge",
7788  "reflect"
7789  };
7790  return Operator("Pad")
7791  .SetParam("mode", PadModeValues[int(mode)])
7792  .SetParam("pad_width", pad_width)
7793  .SetParam("constant_value", constant_value)
7794  .SetInput("data", data)
7795  .CreateSymbol(symbol_name);
7796 }
7797 
7848 inline Symbol InstanceNorm(const std::string& symbol_name,
7849  Symbol data,
7850  Symbol gamma,
7851  Symbol beta,
7852  mx_float eps = 0.00100000005) {
7853  return Operator("InstanceNorm")
7854  .SetParam("eps", eps)
7855  .SetInput("data", data)
7856  .SetInput("gamma", gamma)
7857  .SetInput("beta", beta)
7858  .CreateSymbol(symbol_name);
7859 }
7860 
7865  kAffine = 0,
7866  kWarp = 1
7867 };
7868 
7879 inline Symbol GridGenerator(const std::string& symbol_name,
7880  Symbol data,
7881  GridGeneratorTransformType transform_type,
7882  Shape target_shape = Shape(0,0)) {
7883  static const char *GridGeneratorTransformTypeValues[] = {
7884  "affine",
7885  "warp"
7886  };
7887  return Operator("GridGenerator")
7888  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
7889  .SetParam("target_shape", target_shape)
7890  .SetInput("data", data)
7891  .CreateSymbol(symbol_name);
7892 }
7893 
7897  kAvg = 0,
7898  kMax = 1,
7899  kSum = 2
7900 };
7901 
7905  kFull = 0,
7906  kValid = 1
7907 };
7908 
7960 inline Symbol Pooling_v1(const std::string& symbol_name,
7961  Symbol data,
7962  Shape kernel = Shape(),
7964  bool global_pool = false,
7966  Shape stride = Shape(),
7967  Shape pad = Shape()) {
7968  static const char *Pooling_v1PoolTypeValues[] = {
7969  "avg",
7970  "max",
7971  "sum"
7972  };
7973  static const char *Pooling_v1PoolingConventionValues[] = {
7974  "full",
7975  "valid"
7976  };
7977  return Operator("Pooling_v1")
7978  .SetParam("kernel", kernel)
7979  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
7980  .SetParam("global_pool", global_pool)
7981  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
7982  .SetParam("stride", stride)
7983  .SetParam("pad", pad)
7984  .SetInput("data", data)
7985  .CreateSymbol(symbol_name);
7986 }
7987 
7998  kNone = 0,
7999  kFastest = 1,
8000  kLimited_workspace = 2,
8001  kOff = 3
8002 };
8003 
8008  kNone = 0,
8009  kNCDHW = 1,
8010  kNCHW = 2,
8011  kNDHWC = 3,
8012  kNHWC = 4
8013 };
8014 
8045 inline Symbol Convolution_v1(const std::string& symbol_name,
8046  Symbol data,
8047  Symbol weight,
8048  Symbol bias,
8049  Shape kernel,
8050  uint32_t num_filter,
8051  Shape stride = Shape(),
8052  Shape dilate = Shape(),
8053  Shape pad = Shape(),
8054  uint32_t num_group = 1,
8055  uint64_t workspace = 1024,
8056  bool no_bias = false,
8058  bool cudnn_off = false,
8060  static const char *Convolution_v1CudnnTuneValues[] = {
8061  "None",
8062  "fastest",
8063  "limited_workspace",
8064  "off"
8065  };
8066  static const char *Convolution_v1LayoutValues[] = {
8067  "None",
8068  "NCDHW",
8069  "NCHW",
8070  "NDHWC",
8071  "NHWC"
8072  };
8073  return Operator("Convolution_v1")
8074  .SetParam("kernel", kernel)
8075  .SetParam("num_filter", num_filter)
8076  .SetParam("stride", stride)
8077  .SetParam("dilate", dilate)
8078  .SetParam("pad", pad)
8079  .SetParam("num_group", num_group)
8080  .SetParam("workspace", workspace)
8081  .SetParam("no_bias", no_bias)
8082  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
8083  .SetParam("cudnn_off", cudnn_off)
8084  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
8085  .SetInput("data", data)
8086  .SetInput("weight", weight)
8087  .SetInput("bias", bias)
8088  .CreateSymbol(symbol_name);
8089 }
8090 
8111 inline Symbol Crop(const std::string& symbol_name,
8112  const std::vector<Symbol>& data,
8113  int num_args,
8114  Shape offset = Shape(0,0),
8115  Shape h_w = Shape(0,0),
8116  bool center_crop = false) {
8117  return Operator("Crop")
8118  .SetParam("num_args", num_args)
8119  .SetParam("offset", offset)
8120  .SetParam("h_w", h_w)
8121  .SetParam("center_crop", center_crop)
8122 (data)
8123  .CreateSymbol(symbol_name);
8124 }
8125 
8202 inline Symbol SequenceReverse(const std::string& symbol_name,
8203  Symbol data,
8204  Symbol sequence_length,
8205  bool use_sequence_length = false,
8206  int axis = 0) {
8207  return Operator("SequenceReverse")
8208  .SetParam("use_sequence_length", use_sequence_length)
8209  .SetParam("axis", axis)
8210  .SetInput("data", data)
8211  .SetInput("sequence_length", sequence_length)
8212  .CreateSymbol(symbol_name);
8213 }
8214 
8218  kAffine = 0
8219 };
8220 
8224  kBilinear = 0
8225 };
8226 
8238 inline Symbol SpatialTransformer(const std::string& symbol_name,
8239  Symbol data,
8240  Symbol loc,
8241  SpatialTransformerTransformType transform_type,
8242  SpatialTransformerSamplerType sampler_type,
8243  Shape target_shape = Shape(0,0),
8244  dmlc::optional<bool> cudnn_off = dmlc::optional<bool>()) {
8245  static const char *SpatialTransformerTransformTypeValues[] = {
8246  "affine"
8247  };
8248  static const char *SpatialTransformerSamplerTypeValues[] = {
8249  "bilinear"
8250  };
8251  return Operator("SpatialTransformer")
8252  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
8253  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
8254  .SetParam("target_shape", target_shape)
8255  .SetParam("cudnn_off", cudnn_off)
8256  .SetInput("data", data)
8257  .SetInput("loc", loc)
8258  .CreateSymbol(symbol_name);
8259 }
8260 
8342 inline Symbol BilinearSampler(const std::string& symbol_name,
8343  Symbol data,
8344  Symbol grid,
8345  dmlc::optional<bool> cudnn_off = dmlc::optional<bool>()) {
8346  return Operator("BilinearSampler")
8347  .SetParam("cudnn_off", cudnn_off)
8348  .SetInput("data", data)
8349  .SetInput("grid", grid)
8350  .CreateSymbol(symbol_name);
8351 }
8352 
8409 inline Symbol ROIPooling(const std::string& symbol_name,
8410  Symbol data,
8411  Symbol rois,
8412  Shape pooled_size,
8413  mx_float spatial_scale) {
8414  return Operator("ROIPooling")
8415  .SetParam("pooled_size", pooled_size)
8416  .SetParam("spatial_scale", spatial_scale)
8417  .SetInput("data", data)
8418  .SetInput("rois", rois)
8419  .CreateSymbol(symbol_name);
8420 }
8421 
8477 inline Symbol SequenceLast(const std::string& symbol_name,
8478  Symbol data,
8479  Symbol sequence_length,
8480  bool use_sequence_length = false,
8481  int axis = 0) {
8482  return Operator("SequenceLast")
8483  .SetParam("use_sequence_length", use_sequence_length)
8484  .SetParam("axis", axis)
8485  .SetInput("data", data)
8486  .SetInput("sequence_length", sequence_length)
8487  .CreateSymbol(symbol_name);
8488 }
8489 
8493  kChannel = 0,
8494  kInstance = 1,
8495  kSpatial = 2
8496 };
8497 
8560 inline Symbol L2Normalization(const std::string& symbol_name,
8561  Symbol data,
8562  mx_float eps = 1.00000001e-10,
8564  static const char *L2NormalizationModeValues[] = {
8565  "channel",
8566  "instance",
8567  "spatial"
8568  };
8569  return Operator("L2Normalization")
8570  .SetParam("eps", eps)
8571  .SetParam("mode", L2NormalizationModeValues[int(mode)])
8572  .SetInput("data", data)
8573  .CreateSymbol(symbol_name);
8574 }
8575 
8581  kBatch = 0,
8582  kNull = 1,
8583  kValid = 2
8584 };
8585 
8620 inline Symbol MakeLoss(const std::string& symbol_name,
8621  Symbol data,
8622  mx_float grad_scale = 1,
8623  mx_float valid_thresh = 0,
8625  static const char *MakeLossNormalizationValues[] = {
8626  "batch",
8627  "null",
8628  "valid"
8629  };
8630  return Operator("MakeLoss")
8631  .SetParam("grad_scale", grad_scale)
8632  .SetParam("valid_thresh", valid_thresh)
8633  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
8634  .SetInput("data", data)
8635  .CreateSymbol(symbol_name);
8636 }
8637 
8653 inline Symbol SVMOutput(const std::string& symbol_name,
8654  Symbol data,
8655  Symbol label,
8656  mx_float margin = 1,
8657  mx_float regularization_coefficient = 1,
8658  bool use_linear = false) {
8659  return Operator("SVMOutput")
8660  .SetParam("margin", margin)
8661  .SetParam("regularization_coefficient", regularization_coefficient)
8662  .SetParam("use_linear", use_linear)
8663  .SetInput("data", data)
8664  .SetInput("label", label)
8665  .CreateSymbol(symbol_name);
8666 }
8667 
8717 inline Symbol Correlation(const std::string& symbol_name,
8718  Symbol data1,
8719  Symbol data2,
8720  uint32_t kernel_size = 1,
8721  uint32_t max_displacement = 1,
8722  uint32_t stride1 = 1,
8723  uint32_t stride2 = 1,
8724  uint32_t pad_size = 0,
8725  bool is_multiply = true) {
8726  return Operator("Correlation")
8727  .SetParam("kernel_size", kernel_size)
8728  .SetParam("max_displacement", max_displacement)
8729  .SetParam("stride1", stride1)
8730  .SetParam("stride2", stride2)
8731  .SetParam("pad_size", pad_size)
8732  .SetParam("is_multiply", is_multiply)
8733  .SetInput("data1", data1)
8734  .SetInput("data2", data2)
8735  .CreateSymbol(symbol_name);
8736 }
8737 
8816 inline Symbol SequenceMask(const std::string& symbol_name,
8817  Symbol data,
8818  Symbol sequence_length,
8819  bool use_sequence_length = false,
8820  mx_float value = 0,
8821  int axis = 0) {
8822  return Operator("SequenceMask")
8823  .SetParam("use_sequence_length", use_sequence_length)
8824  .SetParam("value", value)
8825  .SetParam("axis", axis)
8826  .SetInput("data", data)
8827  .SetInput("sequence_length", sequence_length)
8828  .CreateSymbol(symbol_name);
8829 }
8830 
8840 inline Symbol fill_element_0index(const std::string& symbol_name,
8841  Symbol lhs,
8842  Symbol mhs,
8843  Symbol rhs) {
8844  return Operator("fill_element_0index")
8845  .SetInput("lhs", lhs)
8846  .SetInput("mhs", mhs)
8847  .SetInput("rhs", rhs)
8848  .CreateSymbol(symbol_name);
8849 }
8850 
8890 inline Symbol khatri_rao(const std::vector<Symbol>& args) {
8891  return Operator("khatri_rao")
8892 (args)
8893  .CreateSymbol();
8894 }
8895 
8906  bool init_output = true) {
8907  return Operator("all_finite")
8908  .SetParam("init_output", init_output)
8909  .SetInput("data", data)
8910  .CreateSymbol();
8911 }
8912 
8923 inline Symbol multi_all_finite(const std::vector<Symbol>& data,
8924  int num_arrays = 1,
8925  bool init_output = true) {
8926  return Operator("multi_all_finite")
8927  .SetParam("num_arrays", num_arrays)
8928  .SetParam("init_output", init_output)
8929 (data)
8930  .CreateSymbol();
8931 }
8932 
8947 inline Symbol Custom(const std::vector<Symbol>& data,
8948  const std::string& op_type) {
8949  return Operator("Custom")
8950 (data)
8951  .CreateSymbol();
8952 }
8953 
8976  Symbol rhs) {
8977  return Operator("broadcast_power")
8978  .SetInput("lhs", lhs)
8979  .SetInput("rhs", rhs)
8980  .CreateSymbol();
8981 }
8982 
9007  Symbol rhs) {
9008  return Operator("broadcast_maximum")
9009  .SetInput("lhs", lhs)
9010  .SetInput("rhs", rhs)
9011  .CreateSymbol();
9012 }
9013 
9038  Symbol rhs) {
9039  return Operator("broadcast_minimum")
9040  .SetInput("lhs", lhs)
9041  .SetInput("rhs", rhs)
9042  .CreateSymbol();
9043 }
9044 
9075  Symbol rhs) {
9076  return Operator("broadcast_hypot")
9077  .SetInput("lhs", lhs)
9078  .SetInput("rhs", rhs)
9079  .CreateSymbol();
9080 }
9081 
9155 inline Symbol Reshape(Symbol data,
9156  Shape shape = Shape(),
9157  bool reverse = false,
9158  Shape target_shape = Shape(),
9159  bool keep_highest = false) {
9160  return Operator("Reshape")
9161  .SetParam("shape", shape)
9162  .SetParam("reverse", reverse)
9163  .SetParam("target_shape", target_shape)
9164  .SetParam("keep_highest", keep_highest)
9165  .SetInput("data", data)
9166  .CreateSymbol();
9167 }
9168 
9201 inline Symbol Flatten(Symbol data) {
9202  return Operator("Flatten")
9203  .SetInput("data", data)
9204  .CreateSymbol();
9205 }
9206 
9243  Shape axes = Shape()) {
9244  return Operator("transpose")
9245  .SetParam("axes", axes)
9246  .SetInput("data", data)
9247  .CreateSymbol();
9248 }
9249 
9265  int axis) {
9266  return Operator("expand_dims")
9267  .SetParam("axis", axis)
9268  .SetInput("data", data)
9269  .CreateSymbol();
9270 }
9271 
9326 inline Symbol slice(Symbol data,
9327  Shape begin,
9328  Shape end,
9329  Shape step = Shape()) {
9330  return Operator("slice")
9331  .SetParam("begin", begin)
9332  .SetParam("end", end)
9333  .SetParam("step", step)
9334  .SetInput("data", data)
9335  .CreateSymbol();
9336 }
9337 
9370  int axis,
9371  int begin,
9372  dmlc::optional<int> end) {
9373  return Operator("slice_axis")
9374  .SetParam("axis", axis)
9375  .SetParam("begin", begin)
9376  .SetParam("end", end)
9377  .SetInput("data", data)
9378  .CreateSymbol();
9379 }
9380 
9442  Symbol shape_like,
9443  Shape axes = Shape()) {
9444  return Operator("slice_like")
9445  .SetParam("axes", axes)
9446  .SetInput("data", data)
9447  .SetInput("shape_like", shape_like)
9448  .CreateSymbol();
9449 }
9450 
9484 inline Symbol clip(Symbol data,
9485  mx_float a_min,
9486  mx_float a_max) {
9487  return Operator("clip")
9488  .SetParam("a_min", a_min)
9489  .SetParam("a_max", a_max)
9490  .SetInput("data", data)
9491  .CreateSymbol();
9492 }
9493 
9527 inline Symbol repeat(Symbol data,
9528  int repeats,
9529  dmlc::optional<int> axis = dmlc::optional<int>()) {
9530  return Operator("repeat")
9531  .SetParam("repeats", repeats)
9532  .SetParam("axis", axis)
9533  .SetInput("data", data)
9534  .CreateSymbol();
9535 }
9536 
9581 inline Symbol tile(Symbol data,
9582  Shape reps) {
9583  return Operator("tile")
9584  .SetParam("reps", reps)
9585  .SetInput("data", data)
9586  .CreateSymbol();
9587 }
9588 
9611 inline Symbol reverse(Symbol data,
9612  Shape axis) {
9613  return Operator("reverse")
9614  .SetParam("axis", axis)
9615  .SetInput("data", data)
9616  .CreateSymbol();
9617 }
9618 
9641 inline Symbol stack(const std::vector<Symbol>& data,
9642  int num_args,
9643  int axis = 0) {
9644  return Operator("stack")
9645  .SetParam("num_args", num_args)
9646  .SetParam("axis", axis)
9647 (data)
9648  .CreateSymbol();
9649 }
9650 
9672 inline Symbol squeeze(const std::vector<Symbol>& data,
9673  dmlc::optional<Shape> axis = dmlc::optional<Shape>()) {
9674  return Operator("squeeze")
9675  .SetParam("axis", axis)
9676 (data)
9677  .CreateSymbol();
9678 }
9679 
9721  int block_size) {
9722  return Operator("depth_to_space")
9723  .SetParam("block_size", block_size)
9724  .SetInput("data", data)
9725  .CreateSymbol();
9726 }
9727 
9771  int block_size) {
9772  return Operator("space_to_depth")
9773  .SetParam("block_size", block_size)
9774  .SetInput("data", data)
9775  .CreateSymbol();
9776 }
9777 
9800 inline Symbol zeros_like(Symbol data) {
9801  return Operator("zeros_like")
9802  .SetInput("data", data)
9803  .CreateSymbol();
9804 }
9805 
9822 inline Symbol ones_like(Symbol data) {
9823  return Operator("ones_like")
9824  .SetInput("data", data)
9825  .CreateSymbol();
9826 }
9827 
9849 inline Symbol add_n(const std::vector<Symbol>& args) {
9850  return Operator("add_n")
9851 (args)
9852  .CreateSymbol();
9853 }
9854 
9885 inline Symbol argmax(Symbol data,
9886  dmlc::optional<int> axis = dmlc::optional<int>(),
9887  bool keepdims = false) {
9888  return Operator("argmax")
9889  .SetParam("axis", axis)
9890  .SetParam("keepdims", keepdims)
9891  .SetInput("data", data)
9892  .CreateSymbol();
9893 }
9894 
9925 inline Symbol argmin(Symbol data,
9926  dmlc::optional<int> axis = dmlc::optional<int>(),
9927  bool keepdims = false) {
9928  return Operator("argmin")
9929  .SetParam("axis", axis)
9930  .SetParam("keepdims", keepdims)
9931  .SetInput("data", data)
9932  .CreateSymbol();
9933 }
9934 
9957  return Operator("argmax_channel")
9958  .SetInput("data", data)
9959  .CreateSymbol();
9960 }
9961 
10017 inline Symbol pick(Symbol data,
10018  Symbol index,
10019  dmlc::optional<int> axis = dmlc::optional<int>(-1),
10020  bool keepdims = false,
10021  PickMode mode = PickMode::kClip) {
10022  static const char *PickModeValues[] = {
10023  "clip",
10024  "wrap"
10025  };
10026  return Operator("pick")
10027  .SetParam("axis", axis)
10028  .SetParam("keepdims", keepdims)
10029  .SetParam("mode", PickModeValues[int(mode)])
10030  .SetInput("data", data)
10031  .SetInput("index", index)
10032  .CreateSymbol();
10033 }
10034 
10092 inline Symbol dot(Symbol lhs,
10093  Symbol rhs,
10094  bool transpose_a = false,
10095  bool transpose_b = false,
10096  DotForwardStype forward_stype = DotForwardStype::kNone) {
10097  static const char *DotForwardStypeValues[] = {
10098  "None",
10099  "csr",
10100  "default",
10101  "row_sparse"
10102  };
10103  return Operator("dot")
10104  .SetParam("transpose_a", transpose_a)
10105  .SetParam("transpose_b", transpose_b)
10106  .SetParam("forward_stype", DotForwardStypeValues[int(forward_stype)])
10107  .SetInput("lhs", lhs)
10108  .SetInput("rhs", rhs)
10109  .CreateSymbol();
10110 }
10111 
10137  Symbol rhs,
10138  bool transpose_a = false,
10139  bool transpose_b = false,
10141  static const char *Batch_dotForwardStypeValues[] = {
10142  "None",
10143  "csr",
10144  "default",
10145  "row_sparse"
10146  };
10147  return Operator("batch_dot")
10148  .SetParam("transpose_a", transpose_a)
10149  .SetParam("transpose_b", transpose_b)
10150  .SetParam("forward_stype", Batch_dotForwardStypeValues[int(forward_stype)])
10151  .SetInput("lhs", lhs)
10152  .SetInput("rhs", rhs)
10153  .CreateSymbol();
10154 }
10155 
10188  Symbol rhs) {
10189  return Operator("broadcast_add")
10190  .SetInput("lhs", lhs)
10191  .SetInput("rhs", rhs)
10192  .CreateSymbol();
10193 }
10194 
10227  Symbol rhs) {
10228  return Operator("broadcast_sub")
10229  .SetInput("lhs", lhs)
10230  .SetInput("rhs", rhs)
10231  .CreateSymbol();
10232 }
10233 
10260  Symbol rhs) {
10261  return Operator("broadcast_mul")
10262  .SetInput("lhs", lhs)
10263  .SetInput("rhs", rhs)
10264  .CreateSymbol();
10265 }
10266 
10293  Symbol rhs) {
10294  return Operator("broadcast_div")
10295  .SetInput("lhs", lhs)
10296  .SetInput("rhs", rhs)
10297  .CreateSymbol();
10298 }
10299 
10322  Symbol rhs) {
10323  return Operator("broadcast_mod")
10324  .SetInput("lhs", lhs)
10325  .SetInput("rhs", rhs)
10326  .CreateSymbol();
10327 }
10328 
10347 inline Symbol relu(Symbol data) {
10348  return Operator("relu")
10349  .SetInput("data", data)
10350  .CreateSymbol();
10351 }
10352 
10367 inline Symbol sigmoid(Symbol data) {
10368  return Operator("sigmoid")
10369  .SetInput("data", data)
10370  .CreateSymbol();
10371 }
10372 
10388  mx_float alpha = 0.200000003,
10389  mx_float beta = 0.5) {
10390  return Operator("hard_sigmoid")
10391  .SetParam("alpha", alpha)
10392  .SetParam("beta", beta)
10393  .SetInput("data", data)
10394  .CreateSymbol();
10395 }
10396 
10411 inline Symbol softsign(Symbol data) {
10412  return Operator("softsign")
10413  .SetInput("data", data)
10414  .CreateSymbol();
10415 }
10416 
10449 inline Symbol BlockGrad(Symbol data) {
10450  return Operator("BlockGrad")
10451  .SetInput("data", data)
10452  .CreateSymbol();
10453 }
10454 
10483 inline Symbol make_loss(Symbol data) {
10484  return Operator("make_loss")
10485  .SetInput("data", data)
10486  .CreateSymbol();
10487 }
10488 
10523  Symbol rhs) {
10524  return Operator("reshape_like")
10525  .SetInput("lhs", lhs)
10526  .SetInput("rhs", rhs)
10527  .CreateSymbol();
10528 }
10529 
10548  dmlc::optional<int> lhs_begin = dmlc::optional<int>(),
10549  dmlc::optional<int> lhs_end = dmlc::optional<int>(),
10550  dmlc::optional<int> rhs_begin = dmlc::optional<int>(),
10551  dmlc::optional<int> rhs_end = dmlc::optional<int>()) {
10552  return Operator("shape_array")
10553  .SetParam("lhs_begin", lhs_begin)
10554  .SetParam("lhs_end", lhs_end)
10555  .SetParam("rhs_begin", rhs_begin)
10556  .SetParam("rhs_end", rhs_end)
10557  .SetInput("data", data)
10558  .CreateSymbol();
10559 }
10560 
10574 inline Symbol size_array(Symbol data) {
10575  return Operator("size_array")
10576  .SetInput("data", data)
10577  .CreateSymbol();
10578 }
10579 
10598 inline Symbol Cast(Symbol data,
10599  CastDtype dtype) {
10600  static const char *CastDtypeValues[] = {
10601  "float16",
10602  "float32",
10603  "float64",
10604  "int32",
10605  "int64",
10606  "int8",
10607  "uint8"
10608  };
10609  return Operator("Cast")
10610  .SetParam("dtype", CastDtypeValues[int(dtype)])
10611  .SetInput("data", data)
10612  .CreateSymbol();
10613 }
10614 
10628 inline Symbol negative(Symbol data) {
10629  return Operator("negative")
10630  .SetInput("data", data)
10631  .CreateSymbol();
10632 }
10633 
10649 inline Symbol reciprocal(Symbol data) {
10650  return Operator("reciprocal")
10651  .SetInput("data", data)
10652  .CreateSymbol();
10653 }
10654 
10674 inline Symbol abs(Symbol data) {
10675  return Operator("abs")
10676  .SetInput("data", data)
10677  .CreateSymbol();
10678 }
10679 
10699 inline Symbol sign(Symbol data) {
10700  return Operator("sign")
10701  .SetInput("data", data)
10702  .CreateSymbol();
10703 }
10704 
10724 inline Symbol round(Symbol data) {
10725  return Operator("round")
10726  .SetInput("data", data)
10727  .CreateSymbol();
10728 }
10729 
10753 inline Symbol rint(Symbol data) {
10754  return Operator("rint")
10755  .SetInput("data", data)
10756  .CreateSymbol();
10757 }
10758 
10780 inline Symbol ceil(Symbol data) {
10781  return Operator("ceil")
10782  .SetInput("data", data)
10783  .CreateSymbol();
10784 }
10785 
10807 inline Symbol floor(Symbol data) {
10808  return Operator("floor")
10809  .SetInput("data", data)
10810  .CreateSymbol();
10811 }
10812 
10835 inline Symbol trunc(Symbol data) {
10836  return Operator("trunc")
10837  .SetInput("data", data)
10838  .CreateSymbol();
10839 }
10840 
10861 inline Symbol fix(Symbol data) {
10862  return Operator("fix")
10863  .SetInput("data", data)
10864  .CreateSymbol();
10865 }
10866 
10889 inline Symbol square(Symbol data) {
10890  return Operator("square")
10891  .SetInput("data", data)
10892  .CreateSymbol();
10893 }
10894 
10917 inline Symbol sqrt(Symbol data) {
10918  return Operator("sqrt")
10919  .SetInput("data", data)
10920  .CreateSymbol();
10921 }
10922 
10941 inline Symbol rsqrt(Symbol data) {
10942  return Operator("rsqrt")
10943  .SetInput("data", data)
10944  .CreateSymbol();
10945 }
10946 
10969 inline Symbol cbrt(Symbol data) {
10970  return Operator("cbrt")
10971  .SetInput("data", data)
10972  .CreateSymbol();
10973 }
10974 
10988 inline Symbol erf(Symbol data) {
10989  return Operator("erf")
10990  .SetInput("data", data)
10991  .CreateSymbol();
10992 }
10993 
11007 inline Symbol erfinv(Symbol data) {
11008  return Operator("erfinv")
11009  .SetInput("data", data)
11010  .CreateSymbol();
11011 }
11012 
11029 inline Symbol rcbrt(Symbol data) {
11030  return Operator("rcbrt")
11031  .SetInput("data", data)
11032  .CreateSymbol();
11033 }
11034 
11053 inline Symbol exp(Symbol data) {
11054  return Operator("exp")
11055  .SetInput("data", data)
11056  .CreateSymbol();
11057 }
11058 
11072 inline Symbol log(Symbol data) {
11073  return Operator("log")
11074  .SetInput("data", data)
11075  .CreateSymbol();
11076 }
11077 
11091 inline Symbol log10(Symbol data) {
11092  return Operator("log10")
11093  .SetInput("data", data)
11094  .CreateSymbol();
11095 }
11096 
11110 inline Symbol log2(Symbol data) {
11111  return Operator("log2")
11112  .SetInput("data", data)
11113  .CreateSymbol();
11114 }
11115 
11134 inline Symbol log1p(Symbol data) {
11135  return Operator("log1p")
11136  .SetInput("data", data)
11137  .CreateSymbol();
11138 }
11139 
11157 inline Symbol expm1(Symbol data) {
11158  return Operator("expm1")
11159  .SetInput("data", data)
11160  .CreateSymbol();
11161 }
11162 
11173 inline Symbol gamma(Symbol data) {
11174  return Operator("gamma")
11175  .SetInput("data", data)
11176  .CreateSymbol();
11177 }
11178 
11189 inline Symbol gammaln(Symbol data) {
11190  return Operator("gammaln")
11191  .SetInput("data", data)
11192  .CreateSymbol();
11193 }
11194 
11205 inline Symbol logical_not(Symbol data) {
11206  return Operator("logical_not")
11207  .SetInput("data", data)
11208  .CreateSymbol();
11209 }
11210 
11222 inline Symbol amp_cast(Symbol data,
11223  Amp_castDtype dtype) {
11224  static const char *Amp_castDtypeValues[] = {
11225  "float16",
11226  "float32",
11227  "float64",
11228  "int32",
11229  "int64",
11230  "int8",
11231  "uint8"
11232  };
11233  return Operator("amp_cast")
11234  .SetParam("dtype", Amp_castDtypeValues[int(dtype)])
11235  .SetInput("data", data)
11236  .CreateSymbol();
11237 }
11238 
11251 inline Symbol amp_multicast(const std::vector<Symbol>& data,
11252  int num_outputs) {
11253  return Operator("amp_multicast")
11254  .SetParam("num_outputs", num_outputs)
11255 (data)
11256  .CreateSymbol();
11257 }
11258 
11301 inline Symbol topk(Symbol data,
11302  dmlc::optional<int> axis = dmlc::optional<int>(-1),
11303  int k = 1,
11304  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
11305  bool is_ascend = false,
11306  TopkDtype dtype = TopkDtype::kFloat32) {
11307  static const char *TopkRetTypValues[] = {
11308  "both",
11309  "indices",
11310  "mask",
11311  "value"
11312  };
11313  static const char *TopkDtypeValues[] = {
11314  "float16",
11315  "float32",
11316  "float64",
11317  "int32",
11318  "uint8"
11319  };
11320  return Operator("topk")
11321  .SetParam("axis", axis)
11322  .SetParam("k", k)
11323  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
11324  .SetParam("is_ascend", is_ascend)
11325  .SetParam("dtype", TopkDtypeValues[int(dtype)])
11326  .SetInput("data", data)
11327  .CreateSymbol();
11328 }
11329 
11361 inline Symbol sort(Symbol data,
11362  dmlc::optional<int> axis = dmlc::optional<int>(-1),
11363  bool is_ascend = true) {
11364  return Operator("sort")
11365  .SetParam("axis", axis)
11366  .SetParam("is_ascend", is_ascend)
11367  .SetInput("data", data)
11368  .CreateSymbol();
11369 }
11370 
11402 inline Symbol argsort(Symbol data,
11403  dmlc::optional<int> axis = dmlc::optional<int>(-1),
11404  bool is_ascend = true,
11406  static const char *ArgsortDtypeValues[] = {
11407  "float16",
11408  "float32",
11409  "float64",
11410  "int32",
11411  "uint8"
11412  };
11413  return Operator("argsort")
11414  .SetParam("axis", axis)
11415  .SetParam("is_ascend", is_ascend)
11416  .SetParam("dtype", ArgsortDtypeValues[int(dtype)])
11417  .SetInput("data", data)
11418  .CreateSymbol();
11419 }
11420 
11440  Symbol rhs) {
11441  return Operator("elemwise_add")
11442  .SetInput("lhs", lhs)
11443  .SetInput("rhs", rhs)
11444  .CreateSymbol();
11445 }
11446 
11466  Symbol rhs) {
11467  return Operator("elemwise_sub")
11468  .SetInput("lhs", lhs)
11469  .SetInput("rhs", rhs)
11470  .CreateSymbol();
11471 }
11472 
11491  Symbol rhs) {
11492  return Operator("elemwise_mul")
11493  .SetInput("lhs", lhs)
11494  .SetInput("rhs", rhs)
11495  .CreateSymbol();
11496 }
11497 
11509  Symbol rhs) {
11510  return Operator("elemwise_div")
11511  .SetInput("lhs", lhs)
11512  .SetInput("rhs", rhs)
11513  .CreateSymbol();
11514 }
11515 
11579  Symbol weight,
11580  int input_dim,
11581  int output_dim,
11583  bool sparse_grad = false) {
11584  static const char *EmbeddingDtypeValues[] = {
11585  "float16",
11586  "float32",
11587  "float64",
11588  "int32",
11589  "int64",
11590  "int8",
11591  "uint8"
11592  };
11593  return Operator("Embedding")
11594  .SetParam("input_dim", input_dim)
11595  .SetParam("output_dim", output_dim)
11596  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
11597  .SetParam("sparse_grad", sparse_grad)
11598  .SetInput("data", data)
11599  .SetInput("weight", weight)
11600  .CreateSymbol();
11601 }
11602 
11664 inline Symbol take(Symbol a,
11665  Symbol indices,
11666  int axis = 0,
11667  TakeMode mode = TakeMode::kClip) {
11668  static const char *TakeModeValues[] = {
11669  "clip",
11670  "raise",
11671  "wrap"
11672  };
11673  return Operator("take")
11674  .SetParam("axis", axis)
11675  .SetParam("mode", TakeModeValues[int(mode)])
11676  .SetInput("a", a)
11677  .SetInput("indices", indices)
11678  .CreateSymbol();
11679 }
11680 
11709  Symbol indices) {
11710  return Operator("batch_take")
11711  .SetInput("a", a)
11712  .SetInput("indices", indices)
11713  .CreateSymbol();
11714 }
11715 
11759 inline Symbol one_hot(Symbol indices,
11760  int depth,
11761  double on_value = 1,
11762  double off_value = 0,
11764  static const char *One_hotDtypeValues[] = {
11765  "float16",
11766  "float32",
11767  "float64",
11768  "int32",
11769  "int64",
11770  "int8",
11771  "uint8"
11772  };
11773  return Operator("one_hot")
11774  .SetParam("depth", depth)
11775  .SetParam("on_value", on_value)
11776  .SetParam("off_value", off_value)
11777  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
11778  .SetInput("indices", indices)
11779  .CreateSymbol();
11780 }
11781 
11813  Symbol indices) {
11814  return Operator("gather_nd")
11815  .SetInput("data", data)
11816  .SetInput("indices", indices)
11817  .CreateSymbol();
11818 }
11819 
11871  Symbol indices,
11872  Shape shape) {
11873  return Operator("scatter_nd")
11874  .SetParam("shape", shape)
11875  .SetInput("data", data)
11876  .SetInput("indices", indices)
11877  .CreateSymbol();
11878 }
11879 
11902  Symbol rhs) {
11903  return Operator("broadcast_equal")
11904  .SetInput("lhs", lhs)
11905  .SetInput("rhs", rhs)
11906  .CreateSymbol();
11907 }
11908 
11931  Symbol rhs) {
11932  return Operator("broadcast_not_equal")
11933  .SetInput("lhs", lhs)
11934  .SetInput("rhs", rhs)
11935  .CreateSymbol();
11936 }
11937 
11960  Symbol rhs) {
11961  return Operator("broadcast_greater")
11962  .SetInput("lhs", lhs)
11963  .SetInput("rhs", rhs)
11964  .CreateSymbol();
11965 }
11966 
11989  Symbol rhs) {
11990  return Operator("broadcast_greater_equal")
11991  .SetInput("lhs", lhs)
11992  .SetInput("rhs", rhs)
11993  .CreateSymbol();
11994 }
11995 
12018  Symbol rhs) {
12019  return Operator("broadcast_lesser")
12020  .SetInput("lhs", lhs)
12021  .SetInput("rhs", rhs)
12022  .CreateSymbol();
12023 }
12024 
12047  Symbol rhs) {
12048  return Operator("broadcast_lesser_equal")
12049  .SetInput("lhs", lhs)
12050  .SetInput("rhs", rhs)
12051  .CreateSymbol();
12052 }
12053 
12076  Symbol rhs) {
12077  return Operator("broadcast_logical_and")
12078  .SetInput("lhs", lhs)
12079  .SetInput("rhs", rhs)
12080  .CreateSymbol();
12081 }
12082 
12105  Symbol rhs) {
12106  return Operator("broadcast_logical_or")
12107  .SetInput("lhs", lhs)
12108  .SetInput("rhs", rhs)
12109  .CreateSymbol();
12110 }
12111 
12134  Symbol rhs) {
12135  return Operator("broadcast_logical_xor")
12136  .SetInput("lhs", lhs)
12137  .SetInput("rhs", rhs)
12138  .CreateSymbol();
12139 }
12140 
12204 inline Symbol diag(Symbol data,
12205  int k = 0,
12206  int axis1 = 0,
12207  int axis2 = 1) {
12208  return Operator("diag")
12209  .SetParam("k", k)
12210  .SetParam("axis1", axis1)
12211  .SetParam("axis2", axis2)
12212  .SetInput("data", data)
12213  .CreateSymbol();
12214 }
12215 
12250 inline Symbol where(Symbol condition,
12251  Symbol x,
12252  Symbol y) {
12253  return Operator("where")
12254  .SetInput("condition", condition)
12255  .SetInput("x", x)
12256  .SetInput("y", y)
12257  .CreateSymbol();
12258 }
12259 
12286  mx_float scalar) {
12287  return Operator("smooth_l1")
12288  .SetParam("scalar", scalar)
12289  .SetInput("data", data)
12290  .CreateSymbol();
12291 }
12292 
12338  Cast_storageStype stype) {
12339  static const char *Cast_storageStypeValues[] = {
12340  "csr",
12341  "default",
12342  "row_sparse"
12343  };
12344  return Operator("cast_storage")
12345  .SetParam("stype", Cast_storageStypeValues[int(stype)])
12346  .SetInput("data", data)
12347  .CreateSymbol();
12348 }
12349 
12407 inline Symbol sum(Symbol data,
12408  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
12409  bool keepdims = false,
12410  bool exclude = false) {
12411  return Operator("sum")
12412  .SetParam("axis", axis)
12413  .SetParam("keepdims", keepdims)
12414  .SetParam("exclude", exclude)
12415  .SetInput("data", data)
12416  .CreateSymbol();
12417 }
12418 
12442 inline Symbol mean(Symbol data,
12443  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
12444  bool keepdims = false,
12445  bool exclude = false) {
12446  return Operator("mean")
12447  .SetParam("axis", axis)
12448  .SetParam("keepdims", keepdims)
12449  .SetParam("exclude", exclude)
12450  .SetInput("data", data)
12451  .CreateSymbol();
12452 }
12453 
12477 inline Symbol prod(Symbol data,
12478  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
12479  bool keepdims = false,
12480  bool exclude = false) {
12481  return Operator("prod")
12482  .SetParam("axis", axis)
12483  .SetParam("keepdims", keepdims)
12484  .SetParam("exclude", exclude)
12485  .SetInput("data", data)
12486  .CreateSymbol();
12487 }
12488 
12514 inline Symbol nansum(Symbol data,
12515  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
12516  bool keepdims = false,
12517  bool exclude = false) {
12518  return Operator("nansum")
12519  .SetParam("axis", axis)
12520  .SetParam("keepdims", keepdims)
12521  .SetParam("exclude", exclude)
12522  .SetInput("data", data)
12523  .CreateSymbol();
12524 }
12525 
12551 inline Symbol nanprod(Symbol data,
12552  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
12553  bool keepdims = false,
12554  bool exclude = false) {
12555  return Operator("nanprod")
12556  .SetParam("axis", axis)
12557  .SetParam("keepdims", keepdims)
12558  .SetParam("exclude", exclude)
12559  .SetInput("data", data)
12560  .CreateSymbol();
12561 }
12562 
12586 inline Symbol max(Symbol data,
12587  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
12588  bool keepdims = false,
12589  bool exclude = false) {
12590  return Operator("max")
12591  .SetParam("axis", axis)
12592  .SetParam("keepdims", keepdims)
12593  .SetParam("exclude", exclude)
12594  .SetInput("data", data)
12595  .CreateSymbol();
12596 }
12597 
12621 inline Symbol min(Symbol data,
12622  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
12623  bool keepdims = false,
12624  bool exclude = false) {
12625  return Operator("min")
12626  .SetParam("axis", axis)
12627  .SetParam("keepdims", keepdims)
12628  .SetParam("exclude", exclude)
12629  .SetInput("data", data)
12630  .CreateSymbol();
12631 }
12632 
12662  Shape axis = Shape(),
12663  Shape size = Shape()) {
12664  return Operator("broadcast_axis")
12665  .SetParam("axis", axis)
12666  .SetParam("size", size)
12667  .SetInput("data", data)
12668  .CreateSymbol();
12669 }
12670 
12699  Shape shape = Shape()) {
12700  return Operator("broadcast_to")
12701  .SetParam("shape", shape)
12702  .SetInput("data", data)
12703  .CreateSymbol();
12704 }
12705 
12734  Symbol rhs,
12735  dmlc::optional<Shape> lhs_axes = dmlc::optional<Shape>(),
12736  dmlc::optional<Shape> rhs_axes = dmlc::optional<Shape>()) {
12737  return Operator("broadcast_like")
12738  .SetParam("lhs_axes", lhs_axes)
12739  .SetParam("rhs_axes", rhs_axes)
12740  .SetInput("lhs", lhs)
12741  .SetInput("rhs", rhs)
12742  .CreateSymbol();
12743 }
12744 
12788 inline Symbol norm(Symbol data,
12789  int ord = 2,
12790  dmlc::optional<Shape> axis = dmlc::optional<Shape>(),
12791  NormOutDtype out_dtype = NormOutDtype::kNone,
12792  bool keepdims = false) {
12793  static const char *NormOutDtypeValues[] = {
12794  "None",
12795  "float16",
12796  "float32",
12797  "float64",
12798  "int32",
12799  "int64",
12800  "int8"
12801  };
12802  return Operator("norm")
12803  .SetParam("ord", ord)
12804  .SetParam("axis", axis)
12805  .SetParam("out_dtype", NormOutDtypeValues[int(out_dtype)])
12806  .SetParam("keepdims", keepdims)
12807  .SetInput("data", data)
12808  .CreateSymbol();
12809 }
12810 
12831 inline Symbol sin(Symbol data) {
12832  return Operator("sin")
12833  .SetInput("data", data)
12834  .CreateSymbol();
12835 }
12836 
12853 inline Symbol cos(Symbol data) {
12854  return Operator("cos")
12855  .SetInput("data", data)
12856  .CreateSymbol();
12857 }
12858 
12879 inline Symbol tan(Symbol data) {
12880  return Operator("tan")
12881  .SetInput("data", data)
12882  .CreateSymbol();
12883 }
12884 
12906 inline Symbol arcsin(Symbol data) {
12907  return Operator("arcsin")
12908  .SetInput("data", data)
12909  .CreateSymbol();
12910 }
12911 
12929 inline Symbol arccos(Symbol data) {
12930  return Operator("arccos")
12931  .SetInput("data", data)
12932  .CreateSymbol();
12933 }
12934 
12955 inline Symbol arctan(Symbol data) {
12956  return Operator("arctan")
12957  .SetInput("data", data)
12958  .CreateSymbol();
12959 }
12960 
12979 inline Symbol degrees(Symbol data) {
12980  return Operator("degrees")
12981  .SetInput("data", data)
12982  .CreateSymbol();
12983 }
12984 
13003 inline Symbol radians(Symbol data) {
13004  return Operator("radians")
13005  .SetInput("data", data)
13006  .CreateSymbol();
13007 }
13008 
13027 inline Symbol sinh(Symbol data) {
13028  return Operator("sinh")
13029  .SetInput("data", data)
13030  .CreateSymbol();
13031 }
13032 
13047 inline Symbol cosh(Symbol data) {
13048  return Operator("cosh")
13049  .SetInput("data", data)
13050  .CreateSymbol();
13051 }
13052 
13071 inline Symbol tanh(Symbol data) {
13072  return Operator("tanh")
13073  .SetInput("data", data)
13074  .CreateSymbol();
13075 }
13076 
13093 inline Symbol arcsinh(Symbol data) {
13094  return Operator("arcsinh")
13095  .SetInput("data", data)
13096  .CreateSymbol();
13097 }
13098 
13111 inline Symbol arccosh(Symbol data) {
13112  return Operator("arccosh")
13113  .SetInput("data", data)
13114  .CreateSymbol();
13115 }
13116 
13133 inline Symbol arctanh(Symbol data) {
13134  return Operator("arctanh")
13135  .SetInput("data", data)
13136  .CreateSymbol();
13137 }
13138 
13208 inline Symbol Pooling(Symbol data,
13209  Shape kernel = Shape(),
13211  bool global_pool = false,
13212  bool cudnn_off = false,
13214  Shape stride = Shape(),
13215  Shape pad = Shape(),
13216  dmlc::optional<int> p_value = dmlc::optional<int>(),
13217  dmlc::optional<bool> count_include_pad = dmlc::optional<bool>(),
13219  static const char *PoolingPoolTypeValues[] = {
13220  "avg",
13221  "lp",
13222  "max",
13223  "sum"
13224  };
13225  static const char *PoolingPoolingConventionValues[] = {
13226  "full",
13227  "same",
13228  "valid"
13229  };
13230  static const char *PoolingLayoutValues[] = {
13231  "None",
13232  "NCDHW",
13233  "NCHW",
13234  "NCW",
13235  "NDHWC",
13236  "NHWC",
13237  "NWC"
13238  };
13239  return Operator("Pooling")
13240  .SetParam("kernel", kernel)
13241  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
13242  .SetParam("global_pool", global_pool)
13243  .SetParam("cudnn_off", cudnn_off)
13244  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
13245  .SetParam("stride", stride)
13246  .SetParam("pad", pad)
13247  .SetParam("p_value", p_value)
13248  .SetParam("count_include_pad", count_include_pad)
13249  .SetParam("layout", PoolingLayoutValues[int(layout)])
13250  .SetInput("data", data)
13251  .CreateSymbol();
13252 }
13253 
13286 inline Symbol softmax(Symbol data,
13287  int axis = -1,
13288  dmlc::optional<double> temperature = dmlc::optional<double>(),
13290  static const char *SoftmaxDtypeValues[] = {
13291  "None",
13292  "float16",
13293  "float32",
13294  "float64"
13295  };
13296  return Operator("softmax")
13297  .SetParam("axis", axis)
13298  .SetParam("temperature", temperature)
13299  .SetParam("dtype", SoftmaxDtypeValues[int(dtype)])
13300  .SetInput("data", data)
13301  .CreateSymbol();
13302 }
13303 
13337 inline Symbol softmin(Symbol data,
13338  int axis = -1,
13339  dmlc::optional<double> temperature = dmlc::optional<double>(),
13341  static const char *SoftminDtypeValues[] = {
13342  "None",
13343  "float16",
13344  "float32",
13345  "float64"
13346  };
13347  return Operator("softmin")
13348  .SetParam("axis", axis)
13349  .SetParam("temperature", temperature)
13350  .SetParam("dtype", SoftminDtypeValues[int(dtype)])
13351  .SetInput("data", data)
13352  .CreateSymbol();
13353 }
13354 
13379  int axis = -1,
13380  dmlc::optional<double> temperature = dmlc::optional<double>(),
13382  static const char *Log_softmaxDtypeValues[] = {
13383  "None",
13384  "float16",
13385  "float32",
13386  "float64"
13387  };
13388  return Operator("log_softmax")
13389  .SetParam("axis", axis)
13390  .SetParam("temperature", temperature)
13391  .SetParam("dtype", Log_softmaxDtypeValues[int(dtype)])
13392  .SetInput("data", data)
13393  .CreateSymbol();
13394 }
13395 
13425  Symbol weight,
13426  Symbol bias,
13427  Shape kernel,
13428  uint32_t num_filter,
13429  Shape stride = Shape(),
13430  Shape dilate = Shape(),
13431  Shape pad = Shape(),
13432  Shape adj = Shape(),
13433  Shape target_shape = Shape(),
13434  uint32_t num_group = 1,
13435  uint64_t workspace = 512,
13436  bool no_bias = true,
13438  bool cudnn_off = false,
13440  static const char *DeconvolutionCudnnTuneValues[] = {
13441  "None",
13442  "fastest",
13443  "limited_workspace",
13444  "off"
13445  };
13446  static const char *DeconvolutionLayoutValues[] = {
13447  "None",
13448  "NCDHW",
13449  "NCHW",
13450  "NCW",
13451  "NDHWC",
13452  "NHWC"
13453  };
13454  return Operator("Deconvolution")
13455  .SetParam("kernel", kernel)
13456  .SetParam("num_filter", num_filter)
13457  .SetParam("stride", stride)
13458  .SetParam("dilate", dilate)
13459  .SetParam("pad", pad)
13460  .SetParam("adj", adj)
13461  .SetParam("target_shape", target_shape)
13462  .SetParam("num_group", num_group)
13463  .SetParam("workspace", workspace)
13464  .SetParam("no_bias", no_bias)
13465  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
13466  .SetParam("cudnn_off", cudnn_off)
13467  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
13468  .SetInput("data", data)
13469  .SetInput("weight", weight)
13470  .SetInput("bias", bias)
13471  .CreateSymbol();
13472 }
13473 
13493  ActivationActType act_type) {
13494  static const char *ActivationActTypeValues[] = {
13495  "relu",
13496  "sigmoid",
13497  "softrelu",
13498  "softsign",
13499  "tanh"
13500  };
13501  return Operator("Activation")
13502  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
13503  .SetInput("data", data)
13504  .CreateSymbol();
13505 }
13506 
13575  Symbol gamma,
13576  Symbol beta,
13577  Symbol moving_mean,
13578  Symbol moving_var,
13579  double eps = 0.0010000000474974513,
13580  mx_float momentum = 0.899999976,
13581  bool fix_gamma = true,
13582  bool use_global_stats = false,
13583  bool output_mean_var = false,
13584  int axis = 1,
13585  bool cudnn_off = false) {
13586  return Operator("BatchNorm")
13587  .SetParam("eps", eps)
13588  .SetParam("momentum", momentum)
13589  .SetParam("fix_gamma", fix_gamma)
13590  .SetParam("use_global_stats", use_global_stats)
13591  .SetParam("output_mean_var", output_mean_var)
13592  .SetParam("axis", axis)
13593  .SetParam("cudnn_off", cudnn_off)
13594  .SetInput("data", data)
13595  .SetInput("gamma", gamma)
13596  .SetInput("beta", beta)
13597  .SetInput("moving_mean", moving_mean)
13598  .SetInput("moving_var", moving_var)
13599  .CreateSymbol();
13600 }
13601 
13667 inline Symbol CTCLoss(Symbol data,
13668  Symbol label,
13669  Symbol data_lengths,
13670  Symbol label_lengths,
13671  bool use_data_lengths = false,
13672  bool use_label_lengths = false,
13674  static const char *CTCLossBlankLabelValues[] = {
13675  "first",
13676  "last"
13677  };
13678  return Operator("CTCLoss")
13679  .SetParam("use_data_lengths", use_data_lengths)
13680  .SetParam("use_label_lengths", use_label_lengths)
13681  .SetParam("blank_label", CTCLossBlankLabelValues[int(blank_label)])
13682  .SetInput("data", data)
13683  .SetInput("label", label)
13684  .SetInput("data_lengths", data_lengths)
13685  .SetInput("label_lengths", label_lengths)
13686  .CreateSymbol();
13687 }
13688 
13732  Symbol weight,
13733  Symbol bias,
13734  int num_hidden,
13735  bool no_bias = false,
13736  bool flatten = true) {
13737  return Operator("FullyConnected")
13738  .SetParam("num_hidden", num_hidden)
13739  .SetParam("no_bias", no_bias)
13740  .SetParam("flatten", flatten)
13741  .SetInput("data", data)
13742  .SetInput("weight", weight)
13743  .SetInput("bias", bias)
13744  .CreateSymbol();
13745 }
13746 
13844  Symbol weight,
13845  Symbol bias,
13846  Shape kernel,
13847  uint32_t num_filter,
13848  Shape stride = Shape(),
13849  Shape dilate = Shape(),
13850  Shape pad = Shape(),
13851  uint32_t num_group = 1,
13852  uint64_t workspace = 1024,
13853  bool no_bias = false,
13855  bool cudnn_off = false,
13857  static const char *ConvolutionCudnnTuneValues[] = {
13858  "None",
13859  "fastest",
13860  "limited_workspace",
13861  "off"
13862  };
13863  static const char *ConvolutionLayoutValues[] = {
13864  "None",
13865  "NCDHW",
13866  "NCHW",
13867  "NCW",
13868  "NDHWC",
13869  "NHWC"
13870  };
13871  return Operator("Convolution")
13872  .SetParam("kernel", kernel)
13873  .SetParam("num_filter", num_filter)
13874  .SetParam("stride", stride)
13875  .SetParam("dilate", dilate)
13876  .SetParam("pad", pad)
13877  .SetParam("num_group", num_group)
13878  .SetParam("workspace", workspace)
13879  .SetParam("no_bias", no_bias)
13880  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
13881  .SetParam("cudnn_off", cudnn_off)
13882  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
13883  .SetInput("data", data)
13884  .SetInput("weight", weight)
13885  .SetInput("bias", bias)
13886  .CreateSymbol();
13887 }
13888 
13954 inline Symbol UpSampling(const std::vector<Symbol>& data,
13955  int scale,
13956  UpSamplingSampleType sample_type,
13957  int num_args,
13958  int num_filter = 0,
13960  uint64_t workspace = 512) {
13961  static const char *UpSamplingSampleTypeValues[] = {
13962  "bilinear",
13963  "nearest"
13964  };
13965  static const char *UpSamplingMultiInputModeValues[] = {
13966  "concat",
13967  "sum"
13968  };
13969  return Operator("UpSampling")
13970  .SetParam("scale", scale)
13971  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
13972  .SetParam("num_args", num_args)
13973  .SetParam("num_filter", num_filter)
13974  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
13975  .SetParam("workspace", workspace)
13976 (data)
13977  .CreateSymbol();
13978 }
13979 
14025 inline Symbol Concat(const std::vector<Symbol>& data,
14026  int num_args,
14027  int dim = 1) {
14028  return Operator("Concat")
14029  .SetParam("num_args", num_args)
14030  .SetParam("dim", dim)
14031 (data)
14032  .CreateSymbol();
14033 }
14034 
14073  Symbol gamma,
14074  Symbol beta,
14075  int axis = -1,
14076  mx_float eps = 9.99999975e-06,
14077  bool output_mean_var = false) {
14078  return Operator("LayerNorm")
14079  .SetParam("axis", axis)
14080  .SetParam("eps", eps)
14081  .SetParam("output_mean_var", output_mean_var)
14082  .SetInput("data", data)
14083  .SetInput("gamma", gamma)
14084  .SetInput("beta", beta)
14085  .CreateSymbol();
14086 }
14087 
14114 inline Symbol LRN(Symbol data,
14115  uint32_t nsize,
14116  mx_float alpha = 9.99999975e-05,
14117  mx_float beta = 0.75,
14118  mx_float knorm = 2) {
14119  return Operator("LRN")
14120  .SetParam("nsize", nsize)
14121  .SetParam("alpha", alpha)
14122  .SetParam("beta", beta)
14123  .SetParam("knorm", knorm)
14124  .SetInput("data", data)
14125  .CreateSymbol();
14126 }
14127 
14168 inline Symbol Dropout(Symbol data,
14169  mx_float p = 0.5,
14171  Shape axes = Shape(),
14172  dmlc::optional<bool> cudnn_off = dmlc::optional<bool>(0)) {
14173  static const char *DropoutModeValues[] = {
14174  "always",
14175  "training"
14176  };
14177  return Operator("Dropout")
14178  .SetParam("p", p)
14179  .SetParam("mode", DropoutModeValues[int(mode)])
14180  .SetParam("axes", axes)
14181  .SetParam("cudnn_off", cudnn_off)
14182  .SetInput("data", data)
14183  .CreateSymbol();
14184 }
14185 
14220  static const char *SoftmaxActivationModeValues[] = {
14221  "channel",
14222  "instance"
14223  };
14224  return Operator("SoftmaxActivation")
14225  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
14226  .SetInput("data", data)
14227  .CreateSymbol();
14228 }
14229 
14258 inline Symbol moments(Symbol data,
14259  dmlc::optional<Shape> axes = dmlc::optional<Shape>(),
14260  bool keepdims = false) {
14261  return Operator("moments")
14262  .SetParam("axes", axes)
14263  .SetParam("keepdims", keepdims)
14264  .SetInput("data", data)
14265  .CreateSymbol();
14266 }
14267 
14298  Symbol gamma,
14300  mx_float slope = 0.25,
14301  mx_float lower_bound = 0.125,
14302  mx_float upper_bound = 0.333999991) {
14303  static const char *LeakyReLUActTypeValues[] = {
14304  "elu",
14305  "gelu",
14306  "leaky",
14307  "prelu",
14308  "rrelu",
14309  "selu"
14310  };
14311  return Operator("LeakyReLU")
14312  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
14313  .SetParam("slope", slope)
14314  .SetParam("lower_bound", lower_bound)
14315  .SetParam("upper_bound", upper_bound)
14316  .SetInput("data", data)
14317  .SetInput("gamma", gamma)
14318  .CreateSymbol();
14319 }
14320 
14396 inline Symbol RNN(Symbol data,
14397  Symbol parameters,
14398  Symbol state,
14399  Symbol state_cell,
14400  Symbol sequence_length,
14401  uint32_t state_size,
14402  uint32_t num_layers,
14403  RNNMode mode,
14404  bool bidirectional = false,
14405  mx_float p = 0,
14406  bool state_outputs = false,
14407  dmlc::optional<int> projection_size = dmlc::optional<int>(),
14408  dmlc::optional<double> lstm_state_clip_min = dmlc::optional<double>(),
14409  dmlc::optional<double> lstm_state_clip_max = dmlc::optional<double>(),
14410  bool lstm_state_clip_nan = false,
14411  bool use_sequence_length = false) {
14412  static const char *RNNModeValues[] = {
14413  "gru",
14414  "lstm",
14415  "rnn_relu",
14416  "rnn_tanh"
14417  };
14418  return Operator("RNN")
14419  .SetParam("state_size", state_size)
14420  .SetParam("num_layers", num_layers)
14421  .SetParam("mode", RNNModeValues[int(mode)])
14422  .SetParam("bidirectional", bidirectional)
14423  .SetParam("p", p)
14424  .SetParam("state_outputs", state_outputs)
14425  .SetParam("projection_size", projection_size)
14426  .SetParam("lstm_state_clip_min", lstm_state_clip_min)
14427  .SetParam("lstm_state_clip_max", lstm_state_clip_max)
14428  .SetParam("lstm_state_clip_nan", lstm_state_clip_nan)
14429  .SetParam("use_sequence_length", use_sequence_length)
14430  .SetInput("data", data)
14431  .SetInput("parameters", parameters)
14432  .SetInput("state", state)
14433  .SetInput("state_cell", state_cell)
14434  .SetInput("sequence_length", sequence_length)
14435  .CreateSymbol();
14436 }
14437 
14533  Symbol label,
14534  mx_float grad_scale = 1,
14535  mx_float ignore_label = -1,
14536  bool multi_output = false,
14537  bool use_ignore = false,
14538  bool preserve_shape = false,
14540  bool out_grad = false,
14541  mx_float smooth_alpha = 0) {
14542  static const char *SoftmaxOutputNormalizationValues[] = {
14543  "batch",
14544  "null",
14545  "valid"
14546  };
14547  return Operator("SoftmaxOutput")
14548  .SetParam("grad_scale", grad_scale)
14549  .SetParam("ignore_label", ignore_label)
14550  .SetParam("multi_output", multi_output)
14551  .SetParam("use_ignore", use_ignore)
14552  .SetParam("preserve_shape", preserve_shape)
14553  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
14554  .SetParam("out_grad", out_grad)
14555  .SetParam("smooth_alpha", smooth_alpha)
14556  .SetInput("data", data)
14557  .SetInput("label", label)
14558  .CreateSymbol();
14559 }
14560 
14588 inline Symbol SwapAxis(Symbol data,
14589  uint32_t dim1 = 0,
14590  uint32_t dim2 = 0) {
14591  return Operator("SwapAxis")
14592  .SetParam("dim1", dim1)
14593  .SetParam("dim2", dim2)
14594  .SetInput("data", data)
14595  .CreateSymbol();
14596 }
14597 
14658  Symbol gamma,
14659  Symbol beta,
14660  mx_float eps = 0.00100000005,
14661  mx_float momentum = 0.899999976,
14662  bool fix_gamma = true,
14663  bool use_global_stats = false,
14664  bool output_mean_var = false) {
14665  return Operator("BatchNorm_v1")
14666  .SetParam("eps", eps)
14667  .SetParam("momentum", momentum)
14668  .SetParam("fix_gamma", fix_gamma)
14669  .SetParam("use_global_stats", use_global_stats)
14670  .SetParam("output_mean_var", output_mean_var)
14671  .SetInput("data", data)
14672  .SetInput("gamma", gamma)
14673  .SetInput("beta", beta)
14674  .CreateSymbol();
14675 }
14676 
14714  Symbol label) {
14715  return Operator("softmax_cross_entropy")
14716  .SetInput("data", data)
14717  .SetInput("label", label)
14718  .CreateSymbol();
14719 }
14720 
14750  Symbol label,
14751  mx_float grad_scale = 1) {
14752  return Operator("LinearRegressionOutput")
14753  .SetParam("grad_scale", grad_scale)
14754  .SetInput("data", data)
14755  .SetInput("label", label)
14756  .CreateSymbol();
14757 }
14758 
14789  Symbol label,
14790  mx_float grad_scale = 1) {
14791  return Operator("MAERegressionOutput")
14792  .SetParam("grad_scale", grad_scale)
14793  .SetInput("data", data)
14794  .SetInput("label", label)
14795  .CreateSymbol();
14796 }
14797 
14834  Symbol label,
14835  mx_float grad_scale = 1) {
14836  return Operator("LogisticRegressionOutput")
14837  .SetParam("grad_scale", grad_scale)
14838  .SetInput("data", data)
14839  .SetInput("label", label)
14840  .CreateSymbol();
14841 }
14842 
14852  mx_float sparseness_target = 0.100000001,
14853  mx_float penalty = 0.00100000005,
14854  mx_float momentum = 0.899999976) {
14855  return Operator("IdentityAttachKLSparseReg")
14856  .SetParam("sparseness_target", sparseness_target)
14857  .SetParam("penalty", penalty)
14858  .SetParam("momentum", momentum)
14859  .SetInput("data", data)
14860  .CreateSymbol();
14861 }
14862 
14891  Symbol grad,
14892  mx_float lr,
14893  mx_float wd = 0,
14894  mx_float rescale_grad = 1,
14895  mx_float clip_gradient = -1) {
14896  return Operator("signsgd_update")
14897  .SetParam("lr", lr)
14898  .SetParam("wd", wd)
14899  .SetParam("rescale_grad", rescale_grad)
14900  .SetParam("clip_gradient", clip_gradient)
14901  .SetInput("weight", weight)
14902  .SetInput("grad", grad)
14903  .CreateSymbol();
14904 }
14905 
14940  Symbol grad,
14941  Symbol mom,
14942  mx_float lr,
14943  mx_float momentum = 0,
14944  mx_float wd = 0,
14945  mx_float rescale_grad = 1,
14946  mx_float clip_gradient = -1,
14947  mx_float wd_lh = 0) {
14948  return Operator("signum_update")
14949  .SetParam("lr", lr)
14950  .SetParam("momentum", momentum)
14951  .SetParam("wd", wd)
14952  .SetParam("rescale_grad", rescale_grad)
14953  .SetParam("clip_gradient", clip_gradient)
14954  .SetParam("wd_lh", wd_lh)
14955  .SetInput("weight", weight)
14956  .SetInput("grad", grad)
14957  .SetInput("mom", mom)
14958  .CreateSymbol();
14959 }
14960 
14981 inline Symbol multi_sgd_update(const std::vector<Symbol>& data,
14982  nnvm::Tuple<mx_float> lrs,
14983  nnvm::Tuple<mx_float> wds,
14984  mx_float rescale_grad = 1,
14985  mx_float clip_gradient = -1,
14986  int num_weights = 1) {
14987  return Operator("multi_sgd_update")
14988  .SetParam("lrs", lrs)
14989  .SetParam("wds", wds)
14990  .SetParam("rescale_grad", rescale_grad)
14991  .SetParam("clip_gradient", clip_gradient)
14992  .SetParam("num_weights", num_weights)
14993 (data)
14994  .CreateSymbol();
14995 }
14996 
15030 inline Symbol multi_sgd_mom_update(const std::vector<Symbol>& data,
15031  nnvm::Tuple<mx_float> lrs,
15032  nnvm::Tuple<mx_float> wds,
15033  mx_float momentum = 0,
15034  mx_float rescale_grad = 1,
15035  mx_float clip_gradient = -1,
15036  int num_weights = 1) {
15037  return Operator("multi_sgd_mom_update")
15038  .SetParam("lrs", lrs)
15039  .SetParam("wds", wds)
15040  .SetParam("momentum", momentum)
15041  .SetParam("rescale_grad", rescale_grad)
15042  .SetParam("clip_gradient", clip_gradient)
15043  .SetParam("num_weights", num_weights)
15044 (data)
15045  .CreateSymbol();
15046 }
15047 
15068 inline Symbol multi_mp_sgd_update(const std::vector<Symbol>& data,
15069  nnvm::Tuple<mx_float> lrs,
15070  nnvm::Tuple<mx_float> wds,
15071  mx_float rescale_grad = 1,
15072  mx_float clip_gradient = -1,
15073  int num_weights = 1) {
15074  return Operator("multi_mp_sgd_update")
15075  .SetParam("lrs", lrs)
15076  .SetParam("wds", wds)
15077  .SetParam("rescale_grad", rescale_grad)
15078  .SetParam("clip_gradient", clip_gradient)
15079  .SetParam("num_weights", num_weights)
15080 (data)
15081  .CreateSymbol();
15082 }
15083 
15117 inline Symbol multi_mp_sgd_mom_update(const std::vector<Symbol>& data,
15118  nnvm::Tuple<mx_float> lrs,
15119  nnvm::Tuple<mx_float> wds,
15120  mx_float momentum = 0,
15121  mx_float rescale_grad = 1,
15122  mx_float clip_gradient = -1,
15123  int num_weights = 1) {
15124  return Operator("multi_mp_sgd_mom_update")
15125  .SetParam("lrs", lrs)
15126  .SetParam("wds", wds)
15127  .SetParam("momentum", momentum)
15128  .SetParam("rescale_grad", rescale_grad)
15129  .SetParam("clip_gradient", clip_gradient)
15130  .SetParam("num_weights", num_weights)
15131 (data)
15132  .CreateSymbol();
15133 }
15134 
15162 inline Symbol sgd_update(Symbol weight,
15163  Symbol grad,
15164  mx_float lr,
15165  mx_float wd = 0,
15166  mx_float rescale_grad = 1,
15167  mx_float clip_gradient = -1,
15168  bool lazy_update = true) {
15169  return Operator("sgd_update")
15170  .SetParam("lr", lr)
15171  .SetParam("wd", wd)
15172  .SetParam("rescale_grad", rescale_grad)
15173  .SetParam("clip_gradient", clip_gradient)
15174  .SetParam("lazy_update", lazy_update)
15175  .SetInput("weight", weight)
15176  .SetInput("grad", grad)
15177  .CreateSymbol();
15178 }
15179 
15224  Symbol grad,
15225  Symbol mom,
15226  mx_float lr,
15227  mx_float momentum = 0,
15228  mx_float wd = 0,
15229  mx_float rescale_grad = 1,
15230  mx_float clip_gradient = -1,
15231  bool lazy_update = true) {
15232  return Operator("sgd_mom_update")
15233  .SetParam("lr", lr)
15234  .SetParam("momentum", momentum)
15235  .SetParam("wd", wd)
15236  .SetParam("rescale_grad", rescale_grad)
15237  .SetParam("clip_gradient", clip_gradient)
15238  .SetParam("lazy_update", lazy_update)
15239  .SetInput("weight", weight)
15240  .SetInput("grad", grad)
15241  .SetInput("mom", mom)
15242  .CreateSymbol();
15243 }
15244 
15260  Symbol grad,
15261  Symbol weight32,
15262  mx_float lr,
15263  mx_float wd = 0,
15264  mx_float rescale_grad = 1,
15265  mx_float clip_gradient = -1,
15266  bool lazy_update = true) {
15267  return Operator("mp_sgd_update")
15268  .SetParam("lr", lr)
15269  .SetParam("wd", wd)
15270  .SetParam("rescale_grad", rescale_grad)
15271  .SetParam("clip_gradient", clip_gradient)
15272  .SetParam("lazy_update", lazy_update)
15273  .SetInput("weight", weight)
15274  .SetInput("grad", grad)
15275  .SetInput("weight32", weight32)
15276  .CreateSymbol();
15277 }
15278 
15296  Symbol grad,
15297  Symbol mom,
15298  Symbol weight32,
15299  mx_float lr,
15300  mx_float momentum = 0,
15301  mx_float wd = 0,
15302  mx_float rescale_grad = 1,
15303  mx_float clip_gradient = -1,
15304  bool lazy_update = true) {
15305  return Operator("mp_sgd_mom_update")
15306  .SetParam("lr", lr)
15307  .SetParam("momentum", momentum)
15308  .SetParam("wd", wd)
15309  .SetParam("rescale_grad", rescale_grad)
15310  .SetParam("clip_gradient", clip_gradient)
15311  .SetParam("lazy_update", lazy_update)
15312  .SetInput("weight", weight)
15313  .SetInput("grad", grad)
15314  .SetInput("mom", mom)
15315  .SetInput("weight32", weight32)
15316  .CreateSymbol();
15317 }
15318 
15353 inline Symbol ftml_update(Symbol weight,
15354  Symbol grad,
15355  Symbol d,
15356  Symbol v,
15357  Symbol z,
15358  mx_float lr,
15359  int t,
15360  mx_float beta1 = 0.600000024,
15361  mx_float beta2 = 0.999000013,
15362  double epsilon = 9.9999999392252903e-09,
15363  mx_float wd = 0,
15364  mx_float rescale_grad = 1,
15365  mx_float clip_grad = -1) {
15366  return Operator("ftml_update")
15367  .SetParam("lr", lr)
15368  .SetParam("t", t)
15369  .SetParam("beta1", beta1)
15370  .SetParam("beta2", beta2)
15371  .SetParam("epsilon", epsilon)
15372  .SetParam("wd", wd)
15373  .SetParam("rescale_grad", rescale_grad)
15374  .SetParam("clip_grad", clip_grad)
15375  .SetInput("weight", weight)
15376  .SetInput("grad", grad)
15377  .SetInput("d", d)
15378  .SetInput("v", v)
15379  .SetInput("z", z)
15380  .CreateSymbol();
15381 }
15382 
15431 inline Symbol adam_update(Symbol weight,
15432  Symbol grad,
15433  Symbol mean,
15434  Symbol var,
15435  mx_float lr,
15436  mx_float beta1 = 0.899999976,
15437  mx_float beta2 = 0.999000013,
15438  mx_float epsilon = 9.99999994e-09,
15439  mx_float wd = 0,
15440  mx_float rescale_grad = 1,
15441  mx_float clip_gradient = -1,
15442  bool lazy_update = true) {
15443  return Operator("adam_update")
15444  .SetParam("lr", lr)
15445  .SetParam("beta1", beta1)
15446  .SetParam("beta2", beta2)
15447  .SetParam("epsilon", epsilon)
15448  .SetParam("wd", wd)
15449  .SetParam("rescale_grad", rescale_grad)
15450  .SetParam("clip_gradient", clip_gradient)
15451  .SetParam("lazy_update", lazy_update)
15452  .SetInput("weight", weight)
15453  .SetInput("grad", grad)
15454  .SetInput("mean", mean)
15455  .SetInput("var", var)
15456  .CreateSymbol();
15457 }
15458 
15489  Symbol grad,
15490  Symbol mom,
15491  mx_float lr,
15492  mx_float momentum = 0,
15493  mx_float wd = 0,
15494  mx_float rescale_grad = 1,
15495  mx_float clip_gradient = -1) {
15496  return Operator("nag_mom_update")
15497  .SetParam("lr", lr)
15498  .SetParam("momentum", momentum)
15499  .SetParam("wd", wd)
15500  .SetParam("rescale_grad", rescale_grad)
15501  .SetParam("clip_gradient", clip_gradient)
15502  .SetInput("weight", weight)
15503  .SetInput("grad", grad)
15504  .SetInput("mom", mom)
15505  .CreateSymbol();
15506 }
15507 
15527  Symbol grad,
15528  Symbol mom,
15529  Symbol weight32,
15530  mx_float lr,
15531  mx_float momentum = 0,
15532  mx_float wd = 0,
15533  mx_float rescale_grad = 1,
15534  mx_float clip_gradient = -1) {
15535  return Operator("mp_nag_mom_update")
15536  .SetParam("lr", lr)
15537  .SetParam("momentum", momentum)
15538  .SetParam("wd", wd)
15539  .SetParam("rescale_grad", rescale_grad)
15540  .SetParam("clip_gradient", clip_gradient)
15541  .SetInput("weight", weight)
15542  .SetInput("grad", grad)
15543  .SetInput("mom", mom)
15544  .SetInput("weight32", weight32)
15545  .CreateSymbol();
15546 }
15547 
15601  Symbol grad,
15602  Symbol n,
15603  mx_float lr,
15604  mx_float gamma1 = 0.949999988,
15605  mx_float epsilon = 9.99999994e-09,
15606  mx_float wd = 0,
15607  mx_float rescale_grad = 1,
15608  mx_float clip_gradient = -1,
15609  mx_float clip_weights = -1) {
15610  return Operator("rmsprop_update")
15611  .SetParam("lr", lr)
15612  .SetParam("gamma1", gamma1)
15613  .SetParam("epsilon", epsilon)
15614  .SetParam("wd", wd)
15615  .SetParam("rescale_grad", rescale_grad)
15616  .SetParam("clip_gradient", clip_gradient)
15617  .SetParam("clip_weights", clip_weights)
15618  .SetInput("weight", weight)
15619  .SetInput("grad", grad)
15620  .SetInput("n", n)
15621  .CreateSymbol();
15622 }
15623 
15669  Symbol grad,
15670  Symbol n,
15671  Symbol g,
15672  Symbol delta,
15673  mx_float lr,
15674  mx_float gamma1 = 0.949999988,
15675  mx_float gamma2 = 0.899999976,
15676  mx_float epsilon = 9.99999994e-09,
15677  mx_float wd = 0,
15678  mx_float rescale_grad = 1,
15679  mx_float clip_gradient = -1,
15680  mx_float clip_weights = -1) {
15681  return Operator("rmspropalex_update")
15682  .SetParam("lr", lr)
15683  .SetParam("gamma1", gamma1)
15684  .SetParam("gamma2", gamma2)
15685  .SetParam("epsilon", epsilon)
15686  .SetParam("wd", wd)
15687  .SetParam("rescale_grad", rescale_grad)
15688  .SetParam("clip_gradient", clip_gradient)
15689  .SetParam("clip_weights", clip_weights)
15690  .SetInput("weight", weight)
15691  .SetInput("grad", grad)
15692  .SetInput("n", n)
15693  .SetInput("g", g)
15694  .SetInput("delta", delta)
15695  .CreateSymbol();
15696 }
15697 
15736 inline Symbol ftrl_update(Symbol weight,
15737  Symbol grad,
15738  Symbol z,
15739  Symbol n,
15740  mx_float lr,
15741  mx_float lamda1 = 0.00999999978,
15742  mx_float beta = 1,
15743  mx_float wd = 0,
15744  mx_float rescale_grad = 1,
15745  mx_float clip_gradient = -1) {
15746  return Operator("ftrl_update")
15747  .SetParam("lr", lr)
15748  .SetParam("lamda1", lamda1)
15749  .SetParam("beta", beta)
15750  .SetParam("wd", wd)
15751  .SetParam("rescale_grad", rescale_grad)
15752  .SetParam("clip_gradient", clip_gradient)
15753  .SetInput("weight", weight)
15754  .SetInput("grad", grad)
15755  .SetInput("z", z)
15756  .SetInput("n", n)
15757  .CreateSymbol();
15758 }
15759 
15831  int num_outputs,
15832  int axis = 1,
15833  bool squeeze_axis = false) {
15834  return Operator("SliceChannel")
15835  .SetParam("num_outputs", num_outputs)
15836  .SetParam("axis", axis)
15837  .SetParam("squeeze_axis", squeeze_axis)
15838  .SetInput("data", data)
15839  .CreateSymbol();
15840 }
15841 
15937 inline Symbol Pad(Symbol data,
15938  PadMode mode,
15939  Shape pad_width,
15940  double constant_value = 0) {
15941  static const char *PadModeValues[] = {
15942  "constant",
15943  "edge",
15944  "reflect"
15945  };
15946  return Operator("Pad")
15947  .SetParam("mode", PadModeValues[int(mode)])
15948  .SetParam("pad_width", pad_width)
15949  .SetParam("constant_value", constant_value)
15950  .SetInput("data", data)
15951  .CreateSymbol();
15952 }
15953 
16004  Symbol gamma,
16005  Symbol beta,
16006  mx_float eps = 0.00100000005) {
16007  return Operator("InstanceNorm")
16008  .SetParam("eps", eps)
16009  .SetInput("data", data)
16010  .SetInput("gamma", gamma)
16011  .SetInput("beta", beta)
16012  .CreateSymbol();
16013 }
16014 
16025  GridGeneratorTransformType transform_type,
16026  Shape target_shape = Shape(0,0)) {
16027  static const char *GridGeneratorTransformTypeValues[] = {
16028  "affine",
16029  "warp"
16030  };
16031  return Operator("GridGenerator")
16032  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
16033  .SetParam("target_shape", target_shape)
16034  .SetInput("data", data)
16035  .CreateSymbol();
16036 }
16037 
16089  Shape kernel = Shape(),
16091  bool global_pool = false,
16093  Shape stride = Shape(),
16094  Shape pad = Shape()) {
16095  static const char *Pooling_v1PoolTypeValues[] = {
16096  "avg",
16097  "max",
16098  "sum"
16099  };
16100  static const char *Pooling_v1PoolingConventionValues[] = {
16101  "full",
16102  "valid"
16103  };
16104  return Operator("Pooling_v1")
16105  .SetParam("kernel", kernel)
16106  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
16107  .SetParam("global_pool", global_pool)
16108  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
16109  .SetParam("stride", stride)
16110  .SetParam("pad", pad)
16111  .SetInput("data", data)
16112  .CreateSymbol();
16113 }
16114 
16145  Symbol weight,
16146  Symbol bias,
16147  Shape kernel,
16148  uint32_t num_filter,
16149  Shape stride = Shape(),
16150  Shape dilate = Shape(),
16151  Shape pad = Shape(),
16152  uint32_t num_group = 1,
16153  uint64_t workspace = 1024,
16154  bool no_bias = false,
16156  bool cudnn_off = false,
16158  static const char *Convolution_v1CudnnTuneValues[] = {
16159  "None",
16160  "fastest",
16161  "limited_workspace",
16162  "off"
16163  };
16164  static const char *Convolution_v1LayoutValues[] = {
16165  "None",
16166  "NCDHW",
16167  "NCHW",
16168  "NDHWC",
16169  "NHWC"
16170  };
16171  return Operator("Convolution_v1")
16172  .SetParam("kernel", kernel)
16173  .SetParam("num_filter", num_filter)
16174  .SetParam("stride", stride)
16175  .Se