mxnet
op.h
Go to the documentation of this file.
1 
8 #ifndef MXNET_CPP_OP_H_
9 #define MXNET_CPP_OP_H_
10 
11 #include <string>
12 #include <vector>
13 #include "mxnet-cpp/base.h"
14 #include "mxnet-cpp/shape.h"
15 #include "mxnet-cpp/op_util.h"
16 #include "mxnet-cpp/operator.h"
17 #include "dmlc/optional.h"
18 
19 namespace mxnet {
20 namespace cpp {
21 
51 inline Symbol softmax(const std::string& symbol_name,
52  Symbol data,
53  int axis = -1) {
54  return Operator("softmax")
55  .SetParam("axis", axis)
56  .SetInput("data", data)
57  .CreateSymbol(symbol_name);
58 }
59 
82 inline Symbol log_softmax(const std::string& symbol_name,
83  Symbol data,
84  int axis = -1) {
85  return Operator("log_softmax")
86  .SetParam("axis", axis)
87  .SetInput("data", data)
88  .CreateSymbol(symbol_name);
89 }
90 
113 inline Symbol broadcast_power(const std::string& symbol_name,
114  Symbol lhs,
115  Symbol rhs) {
116  return Operator("broadcast_power")
117  .SetInput("lhs", lhs)
118  .SetInput("rhs", rhs)
119  .CreateSymbol(symbol_name);
120 }
121 
146 inline Symbol broadcast_maximum(const std::string& symbol_name,
147  Symbol lhs,
148  Symbol rhs) {
149  return Operator("broadcast_maximum")
150  .SetInput("lhs", lhs)
151  .SetInput("rhs", rhs)
152  .CreateSymbol(symbol_name);
153 }
154 
179 inline Symbol broadcast_minimum(const std::string& symbol_name,
180  Symbol lhs,
181  Symbol rhs) {
182  return Operator("broadcast_minimum")
183  .SetInput("lhs", lhs)
184  .SetInput("rhs", rhs)
185  .CreateSymbol(symbol_name);
186 }
187 
218 inline Symbol broadcast_hypot(const std::string& symbol_name,
219  Symbol lhs,
220  Symbol rhs) {
221  return Operator("broadcast_hypot")
222  .SetInput("lhs", lhs)
223  .SetInput("rhs", rhs)
224  .CreateSymbol(symbol_name);
225 }
226 
301 inline Symbol Reshape(const std::string& symbol_name,
302  Symbol data,
303  Shape shape = Shape(),
304  bool reverse = 0,
305  Shape target_shape = Shape(),
306  bool keep_highest = 0) {
307  return Operator("Reshape")
308  .SetParam("shape", shape)
309  .SetParam("reverse", reverse)
310  .SetParam("target_shape", target_shape)
311  .SetParam("keep_highest", keep_highest)
312  .SetInput("data", data)
313  .CreateSymbol(symbol_name);
314 }
315 
346 inline Symbol Flatten(const std::string& symbol_name,
347  Symbol data) {
348  return Operator("Flatten")
349  .SetInput("data", data)
350  .CreateSymbol(symbol_name);
351 }
352 
389 inline Symbol transpose(const std::string& symbol_name,
390  Symbol data,
391  Shape axes = Shape()) {
392  return Operator("transpose")
393  .SetParam("axes", axes)
394  .SetInput("data", data)
395  .CreateSymbol(symbol_name);
396 }
397 
413 inline Symbol expand_dims(const std::string& symbol_name,
414  Symbol data,
415  int axis) {
416  return Operator("expand_dims")
417  .SetParam("axis", axis)
418  .SetInput("data", data)
419  .CreateSymbol(symbol_name);
420 }
421 
458 inline Symbol slice(const std::string& symbol_name,
459  Symbol data,
460  Shape begin,
461  Shape end) {
462  return Operator("slice")
463  .SetParam("begin", begin)
464  .SetParam("end", end)
465  .SetInput("data", data)
466  .CreateSymbol(symbol_name);
467 }
468 
501 inline Symbol slice_axis(const std::string& symbol_name,
502  Symbol data,
503  int axis,
504  int begin,
505  dmlc::optional<int> end) {
506  return Operator("slice_axis")
507  .SetParam("axis", axis)
508  .SetParam("begin", begin)
509  .SetParam("end", end)
510  .SetInput("data", data)
511  .CreateSymbol(symbol_name);
512 }
513 
548 inline Symbol clip(const std::string& symbol_name,
549  Symbol data,
550  mx_float a_min,
551  mx_float a_max) {
552  return Operator("clip")
553  .SetParam("a_min", a_min)
554  .SetParam("a_max", a_max)
555  .SetInput("data", data)
556  .CreateSymbol(symbol_name);
557 }
558 
593 inline Symbol repeat(const std::string& symbol_name,
594  Symbol data,
595  int repeats,
596  dmlc::optional<int> axis = dmlc::optional<int>()) {
597  return Operator("repeat")
598  .SetParam("repeats", repeats)
599  .SetParam("axis", axis)
600  .SetInput("data", data)
601  .CreateSymbol(symbol_name);
602 }
603 
649 inline Symbol tile(const std::string& symbol_name,
650  Symbol data,
651  Shape reps) {
652  return Operator("tile")
653  .SetParam("reps", reps)
654  .SetInput("data", data)
655  .CreateSymbol(symbol_name);
656 }
657 
681 inline Symbol reverse(const std::string& symbol_name,
682  Symbol data,
683  Shape axis) {
684  return Operator("reverse")
685  .SetParam("axis", axis)
686  .SetInput("data", data)
687  .CreateSymbol(symbol_name);
688 }
689 
713 inline Symbol stack(const std::string& symbol_name,
714  const std::vector<Symbol>& data,
715  int num_args,
716  int axis = 0) {
717  return Operator("stack")
718  .SetParam("num_args", num_args)
719  .SetParam("axis", axis)
720 (data)
721  .CreateSymbol(symbol_name);
722 }
723 
747 inline Symbol zeros_like(const std::string& symbol_name,
748  Symbol data) {
749  return Operator("zeros_like")
750  .SetInput("data", data)
751  .CreateSymbol(symbol_name);
752 }
753 
771 inline Symbol ones_like(const std::string& symbol_name,
772  Symbol data) {
773  return Operator("ones_like")
774  .SetInput("data", data)
775  .CreateSymbol(symbol_name);
776 }
777 
805 inline Symbol broadcast_add(const std::string& symbol_name,
806  Symbol lhs,
807  Symbol rhs) {
808  return Operator("broadcast_add")
809  .SetInput("lhs", lhs)
810  .SetInput("rhs", rhs)
811  .CreateSymbol(symbol_name);
812 }
813 
841 inline Symbol broadcast_sub(const std::string& symbol_name,
842  Symbol lhs,
843  Symbol rhs) {
844  return Operator("broadcast_sub")
845  .SetInput("lhs", lhs)
846  .SetInput("rhs", rhs)
847  .CreateSymbol(symbol_name);
848 }
849 
872 inline Symbol broadcast_mul(const std::string& symbol_name,
873  Symbol lhs,
874  Symbol rhs) {
875  return Operator("broadcast_mul")
876  .SetInput("lhs", lhs)
877  .SetInput("rhs", rhs)
878  .CreateSymbol(symbol_name);
879 }
880 
903 inline Symbol broadcast_div(const std::string& symbol_name,
904  Symbol lhs,
905  Symbol rhs) {
906  return Operator("broadcast_div")
907  .SetInput("lhs", lhs)
908  .SetInput("rhs", rhs)
909  .CreateSymbol(symbol_name);
910 }
911 
934 inline Symbol broadcast_mod(const std::string& symbol_name,
935  Symbol lhs,
936  Symbol rhs) {
937  return Operator("broadcast_mod")
938  .SetInput("lhs", lhs)
939  .SetInput("rhs", rhs)
940  .CreateSymbol(symbol_name);
941 }
942 
963 inline Symbol add_n(const std::string& symbol_name,
964  const std::vector<Symbol>& args) {
965  return Operator("add_n")
966 (args)
967  .CreateSymbol(symbol_name);
968 }
969 
1001 inline Symbol argmax(const std::string& symbol_name,
1002  Symbol data,
1003  dmlc::optional<int> axis = dmlc::optional<int>(),
1004  bool keepdims = 0) {
1005  return Operator("argmax")
1006  .SetParam("axis", axis)
1007  .SetParam("keepdims", keepdims)
1008  .SetInput("data", data)
1009  .CreateSymbol(symbol_name);
1010 }
1011 
1043 inline Symbol argmin(const std::string& symbol_name,
1044  Symbol data,
1045  dmlc::optional<int> axis = dmlc::optional<int>(),
1046  bool keepdims = 0) {
1047  return Operator("argmin")
1048  .SetParam("axis", axis)
1049  .SetParam("keepdims", keepdims)
1050  .SetInput("data", data)
1051  .CreateSymbol(symbol_name);
1052 }
1053 
1076 inline Symbol argmax_channel(const std::string& symbol_name,
1077  Symbol data) {
1078  return Operator("argmax_channel")
1079  .SetInput("data", data)
1080  .CreateSymbol(symbol_name);
1081 }
1082 
1128 inline Symbol pick(const std::string& symbol_name,
1129  Symbol data,
1130  Symbol index,
1131  dmlc::optional<int> axis = dmlc::optional<int>(),
1132  bool keepdims = 0) {
1133  return Operator("pick")
1134  .SetParam("axis", axis)
1135  .SetParam("keepdims", keepdims)
1136  .SetInput("data", data)
1137  .SetInput("index", index)
1138  .CreateSymbol(symbol_name);
1139 }
1140 
1180 inline Symbol dot(const std::string& symbol_name,
1181  Symbol lhs,
1182  Symbol rhs,
1183  bool transpose_a = 0,
1184  bool transpose_b = 0) {
1185  return Operator("dot")
1186  .SetParam("transpose_a", transpose_a)
1187  .SetParam("transpose_b", transpose_b)
1188  .SetInput("lhs", lhs)
1189  .SetInput("rhs", rhs)
1190  .CreateSymbol(symbol_name);
1191 }
1192 
1215 inline Symbol batch_dot(const std::string& symbol_name,
1216  Symbol lhs,
1217  Symbol rhs,
1218  bool transpose_a = 0,
1219  bool transpose_b = 0) {
1220  return Operator("batch_dot")
1221  .SetParam("transpose_a", transpose_a)
1222  .SetParam("transpose_b", transpose_b)
1223  .SetInput("lhs", lhs)
1224  .SetInput("rhs", rhs)
1225  .CreateSymbol(symbol_name);
1226 }
1227 
1246 inline Symbol relu(const std::string& symbol_name,
1247  Symbol data) {
1248  return Operator("relu")
1249  .SetInput("data", data)
1250  .CreateSymbol(symbol_name);
1251 }
1252 
1268 inline Symbol sigmoid(const std::string& symbol_name,
1269  Symbol data) {
1270  return Operator("sigmoid")
1271  .SetInput("data", data)
1272  .CreateSymbol(symbol_name);
1273 }
1274 
1308 inline Symbol BlockGrad(const std::string& symbol_name,
1309  Symbol data) {
1310  return Operator("BlockGrad")
1311  .SetInput("data", data)
1312  .CreateSymbol(symbol_name);
1313 }
1314 
1344 inline Symbol make_loss(const std::string& symbol_name,
1345  Symbol data) {
1346  return Operator("make_loss")
1347  .SetInput("data", data)
1348  .CreateSymbol(symbol_name);
1349 }
1350 
1358 inline Symbol reshape_like(const std::string& symbol_name,
1359  Symbol lhs,
1360  Symbol rhs) {
1361  return Operator("reshape_like")
1362  .SetInput("lhs", lhs)
1363  .SetInput("rhs", rhs)
1364  .CreateSymbol(symbol_name);
1365 }
1366 
/*! \brief dtype choices for Cast(). The integer value of each enumerator is
 *  used by Cast() as an index into its dtype-name table, so the order and
 *  values here must stay in sync with that table.
 */
enum class CastDtype : int {
  kFloat16 = 0,
  kFloat32 = 1,
  kFloat64 = 2,
  kInt32 = 3,
  kUint8 = 4,
};
1376 
1396 inline Symbol Cast(const std::string& symbol_name,
1397  Symbol data,
1398  CastDtype dtype) {
1399  static const char *CastDtypeValues[] = {
1400  "float16",
1401  "float32",
1402  "float64",
1403  "int32",
1404  "uint8"
1405  };
1406  return Operator("Cast")
1407  .SetParam("dtype", CastDtypeValues[int(dtype)])
1408  .SetInput("data", data)
1409  .CreateSymbol(symbol_name);
1410 }
1411 
1426 inline Symbol negative(const std::string& symbol_name,
1427  Symbol data) {
1428  return Operator("negative")
1429  .SetInput("data", data)
1430  .CreateSymbol(symbol_name);
1431 }
1432 
1449 inline Symbol reciprocal(const std::string& symbol_name,
1450  Symbol data) {
1451  return Operator("reciprocal")
1452  .SetInput("data", data)
1453  .CreateSymbol(symbol_name);
1454 }
1455 
1475 inline Symbol abs(const std::string& symbol_name,
1476  Symbol data) {
1477  return Operator("abs")
1478  .SetInput("data", data)
1479  .CreateSymbol(symbol_name);
1480 }
1481 
1501 inline Symbol sign(const std::string& symbol_name,
1502  Symbol data) {
1503  return Operator("sign")
1504  .SetInput("data", data)
1505  .CreateSymbol(symbol_name);
1506 }
1507 
1527 inline Symbol round(const std::string& symbol_name,
1528  Symbol data) {
1529  return Operator("round")
1530  .SetInput("data", data)
1531  .CreateSymbol(symbol_name);
1532 }
1533 
1557 inline Symbol rint(const std::string& symbol_name,
1558  Symbol data) {
1559  return Operator("rint")
1560  .SetInput("data", data)
1561  .CreateSymbol(symbol_name);
1562 }
1563 
1585 inline Symbol ceil(const std::string& symbol_name,
1586  Symbol data) {
1587  return Operator("ceil")
1588  .SetInput("data", data)
1589  .CreateSymbol(symbol_name);
1590 }
1591 
1613 inline Symbol floor(const std::string& symbol_name,
1614  Symbol data) {
1615  return Operator("floor")
1616  .SetInput("data", data)
1617  .CreateSymbol(symbol_name);
1618 }
1619 
1642 inline Symbol trunc(const std::string& symbol_name,
1643  Symbol data) {
1644  return Operator("trunc")
1645  .SetInput("data", data)
1646  .CreateSymbol(symbol_name);
1647 }
1648 
1669 inline Symbol fix(const std::string& symbol_name,
1670  Symbol data) {
1671  return Operator("fix")
1672  .SetInput("data", data)
1673  .CreateSymbol(symbol_name);
1674 }
1675 
1699 inline Symbol square(const std::string& symbol_name,
1700  Symbol data) {
1701  return Operator("square")
1702  .SetInput("data", data)
1703  .CreateSymbol(symbol_name);
1704 }
1705 
1728 inline Symbol sqrt(const std::string& symbol_name,
1729  Symbol data) {
1730  return Operator("sqrt")
1731  .SetInput("data", data)
1732  .CreateSymbol(symbol_name);
1733 }
1734 
1754 inline Symbol rsqrt(const std::string& symbol_name,
1755  Symbol data) {
1756  return Operator("rsqrt")
1757  .SetInput("data", data)
1758  .CreateSymbol(symbol_name);
1759 }
1760 
1778 inline Symbol cbrt(const std::string& symbol_name,
1779  Symbol data) {
1780  return Operator("cbrt")
1781  .SetInput("data", data)
1782  .CreateSymbol(symbol_name);
1783 }
1784 
1802 inline Symbol rcbrt(const std::string& symbol_name,
1803  Symbol data) {
1804  return Operator("rcbrt")
1805  .SetInput("data", data)
1806  .CreateSymbol(symbol_name);
1807 }
1808 
1828 inline Symbol exp(const std::string& symbol_name,
1829  Symbol data) {
1830  return Operator("exp")
1831  .SetInput("data", data)
1832  .CreateSymbol(symbol_name);
1833 }
1834 
1849 inline Symbol log(const std::string& symbol_name,
1850  Symbol data) {
1851  return Operator("log")
1852  .SetInput("data", data)
1853  .CreateSymbol(symbol_name);
1854 }
1855 
1870 inline Symbol log10(const std::string& symbol_name,
1871  Symbol data) {
1872  return Operator("log10")
1873  .SetInput("data", data)
1874  .CreateSymbol(symbol_name);
1875 }
1876 
1891 inline Symbol log2(const std::string& symbol_name,
1892  Symbol data) {
1893  return Operator("log2")
1894  .SetInput("data", data)
1895  .CreateSymbol(symbol_name);
1896 }
1897 
1916 inline Symbol log1p(const std::string& symbol_name,
1917  Symbol data) {
1918  return Operator("log1p")
1919  .SetInput("data", data)
1920  .CreateSymbol(symbol_name);
1921 }
1922 
1940 inline Symbol expm1(const std::string& symbol_name,
1941  Symbol data) {
1942  return Operator("expm1")
1943  .SetInput("data", data)
1944  .CreateSymbol(symbol_name);
1945 }
1946 
1958 inline Symbol gamma(const std::string& symbol_name,
1959  Symbol data) {
1960  return Operator("gamma")
1961  .SetInput("data", data)
1962  .CreateSymbol(symbol_name);
1963 }
1964 
1976 inline Symbol gammaln(const std::string& symbol_name,
1977  Symbol data) {
1978  return Operator("gammaln")
1979  .SetInput("data", data)
1980  .CreateSymbol(symbol_name);
1981 }
1982 
2041 inline Symbol sum(const std::string& symbol_name,
2042  Symbol data,
2043  Shape axis = Shape(),
2044  bool keepdims = 0,
2045  bool exclude = 0) {
2046  return Operator("sum")
2047  .SetParam("axis", axis)
2048  .SetParam("keepdims", keepdims)
2049  .SetParam("exclude", exclude)
2050  .SetInput("data", data)
2051  .CreateSymbol(symbol_name);
2052 }
2053 
2078 inline Symbol mean(const std::string& symbol_name,
2079  Symbol data,
2080  Shape axis = Shape(),
2081  bool keepdims = 0,
2082  bool exclude = 0) {
2083  return Operator("mean")
2084  .SetParam("axis", axis)
2085  .SetParam("keepdims", keepdims)
2086  .SetParam("exclude", exclude)
2087  .SetInput("data", data)
2088  .CreateSymbol(symbol_name);
2089 }
2090 
2115 inline Symbol prod(const std::string& symbol_name,
2116  Symbol data,
2117  Shape axis = Shape(),
2118  bool keepdims = 0,
2119  bool exclude = 0) {
2120  return Operator("prod")
2121  .SetParam("axis", axis)
2122  .SetParam("keepdims", keepdims)
2123  .SetParam("exclude", exclude)
2124  .SetInput("data", data)
2125  .CreateSymbol(symbol_name);
2126 }
2127 
2154 inline Symbol nansum(const std::string& symbol_name,
2155  Symbol data,
2156  Shape axis = Shape(),
2157  bool keepdims = 0,
2158  bool exclude = 0) {
2159  return Operator("nansum")
2160  .SetParam("axis", axis)
2161  .SetParam("keepdims", keepdims)
2162  .SetParam("exclude", exclude)
2163  .SetInput("data", data)
2164  .CreateSymbol(symbol_name);
2165 }
2166 
2193 inline Symbol nanprod(const std::string& symbol_name,
2194  Symbol data,
2195  Shape axis = Shape(),
2196  bool keepdims = 0,
2197  bool exclude = 0) {
2198  return Operator("nanprod")
2199  .SetParam("axis", axis)
2200  .SetParam("keepdims", keepdims)
2201  .SetParam("exclude", exclude)
2202  .SetInput("data", data)
2203  .CreateSymbol(symbol_name);
2204 }
2205 
2230 inline Symbol max(const std::string& symbol_name,
2231  Symbol data,
2232  Shape axis = Shape(),
2233  bool keepdims = 0,
2234  bool exclude = 0) {
2235  return Operator("max")
2236  .SetParam("axis", axis)
2237  .SetParam("keepdims", keepdims)
2238  .SetParam("exclude", exclude)
2239  .SetInput("data", data)
2240  .CreateSymbol(symbol_name);
2241 }
2242 
2267 inline Symbol min(const std::string& symbol_name,
2268  Symbol data,
2269  Shape axis = Shape(),
2270  bool keepdims = 0,
2271  bool exclude = 0) {
2272  return Operator("min")
2273  .SetParam("axis", axis)
2274  .SetParam("keepdims", keepdims)
2275  .SetParam("exclude", exclude)
2276  .SetInput("data", data)
2277  .CreateSymbol(symbol_name);
2278 }
2279 
2309 inline Symbol broadcast_axis(const std::string& symbol_name,
2310  Symbol data,
2311  Shape axis = Shape(),
2312  Shape size = Shape()) {
2313  return Operator("broadcast_axis")
2314  .SetParam("axis", axis)
2315  .SetParam("size", size)
2316  .SetInput("data", data)
2317  .CreateSymbol(symbol_name);
2318 }
2319 
2348 inline Symbol broadcast_to(const std::string& symbol_name,
2349  Symbol data,
2350  Shape shape = Shape()) {
2351  return Operator("broadcast_to")
2352  .SetParam("shape", shape)
2353  .SetInput("data", data)
2354  .CreateSymbol(symbol_name);
2355 }
2356 
2374 inline Symbol norm(const std::string& symbol_name,
2375  Symbol data) {
2376  return Operator("norm")
2377  .SetInput("data", data)
2378  .CreateSymbol(symbol_name);
2379 }
2380 
/*! \brief ret_typ choices for topk(). topk() uses each enumerator's integer
 *  value as an index into its name table, so order and values must match it.
 */
enum class TopkRetTyp : int {
  kBoth = 0,
  kIndices = 1,
  kMask = 2,
  kValue = 3,
};
2392 
2434 inline Symbol topk(const std::string& symbol_name,
2435  Symbol data,
2436  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2437  int k = 1,
2438  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
2439  bool is_ascend = 0) {
2440  static const char *TopkRetTypValues[] = {
2441  "both",
2442  "indices",
2443  "mask",
2444  "value"
2445  };
2446  return Operator("topk")
2447  .SetParam("axis", axis)
2448  .SetParam("k", k)
2449  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
2450  .SetParam("is_ascend", is_ascend)
2451  .SetInput("data", data)
2452  .CreateSymbol(symbol_name);
2453 }
2454 
2487 inline Symbol sort(const std::string& symbol_name,
2488  Symbol data,
2489  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2490  bool is_ascend = 1) {
2491  return Operator("sort")
2492  .SetParam("axis", axis)
2493  .SetParam("is_ascend", is_ascend)
2494  .SetInput("data", data)
2495  .CreateSymbol(symbol_name);
2496 }
2497 
2528 inline Symbol argsort(const std::string& symbol_name,
2529  Symbol data,
2530  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2531  bool is_ascend = 1) {
2532  return Operator("argsort")
2533  .SetParam("axis", axis)
2534  .SetParam("is_ascend", is_ascend)
2535  .SetInput("data", data)
2536  .CreateSymbol(symbol_name);
2537 }
2538 
2553 inline Symbol elemwise_add(const std::string& symbol_name,
2554  Symbol lhs,
2555  Symbol rhs) {
2556  return Operator("elemwise_add")
2557  .SetInput("lhs", lhs)
2558  .SetInput("rhs", rhs)
2559  .CreateSymbol(symbol_name);
2560 }
2561 
2576 inline Symbol elemwise_sub(const std::string& symbol_name,
2577  Symbol lhs,
2578  Symbol rhs) {
2579  return Operator("elemwise_sub")
2580  .SetInput("lhs", lhs)
2581  .SetInput("rhs", rhs)
2582  .CreateSymbol(symbol_name);
2583 }
2584 
2602 inline Symbol elemwise_mul(const std::string& symbol_name,
2603  Symbol lhs,
2604  Symbol rhs) {
2605  return Operator("elemwise_mul")
2606  .SetInput("lhs", lhs)
2607  .SetInput("rhs", rhs)
2608  .CreateSymbol(symbol_name);
2609 }
2610 
2622 inline Symbol elemwise_div(const std::string& symbol_name,
2623  Symbol lhs,
2624  Symbol rhs) {
2625  return Operator("elemwise_div")
2626  .SetInput("lhs", lhs)
2627  .SetInput("rhs", rhs)
2628  .CreateSymbol(symbol_name);
2629 }
2630 
/*! \brief dtype choices for Embedding(). Embedding() uses each enumerator's
 *  integer value as an index into its dtype-name table, so order and values
 *  must match it.
 */
enum class EmbeddingDtype : int {
  kFloat16 = 0,
  kFloat32 = 1,
  kFloat64 = 2,
  kInt32 = 3,
  kUint8 = 4,
};
2640 
2692 inline Symbol Embedding(const std::string& symbol_name,
2693  Symbol data,
2694  Symbol weight,
2695  int input_dim,
2696  int output_dim,
2698  static const char *EmbeddingDtypeValues[] = {
2699  "float16",
2700  "float32",
2701  "float64",
2702  "int32",
2703  "uint8"
2704  };
2705  return Operator("Embedding")
2706  .SetParam("input_dim", input_dim)
2707  .SetParam("output_dim", output_dim)
2708  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
2709  .SetInput("data", data)
2710  .SetInput("weight", weight)
2711  .CreateSymbol(symbol_name);
2712 }
2713 
/*! \brief mode choices for take(). take() uses each enumerator's integer value
 *  as an index into its mode-name table, so order and values must match it.
 */
enum class TakeMode : int {
  kClip = 0,
  kRaise = 1,
  kWrap = 2,
};
2723 
2763 inline Symbol take(const std::string& symbol_name,
2764  Symbol a,
2765  Symbol indices,
2766  int axis = 0,
2767  TakeMode mode = TakeMode::kClip) {
2768  static const char *TakeModeValues[] = {
2769  "clip",
2770  "raise",
2771  "wrap"
2772  };
2773  return Operator("take")
2774  .SetParam("axis", axis)
2775  .SetParam("mode", TakeModeValues[int(mode)])
2776  .SetInput("a", a)
2777  .SetInput("indices", indices)
2778  .CreateSymbol(symbol_name);
2779 }
2780 
2809 inline Symbol batch_take(const std::string& symbol_name,
2810  Symbol a,
2811  Symbol indices) {
2812  return Operator("batch_take")
2813  .SetInput("a", a)
2814  .SetInput("indices", indices)
2815  .CreateSymbol(symbol_name);
2816 }
2817 
/*! \brief dtype choices for one_hot(). one_hot() uses each enumerator's
 *  integer value as an index into its dtype-name table, so order and values
 *  must match it.
 */
enum class One_hotDtype : int {
  kFloat16 = 0,
  kFloat32 = 1,
  kFloat64 = 2,
  kInt32 = 3,
  kUint8 = 4,
};
2827 
2872 inline Symbol one_hot(const std::string& symbol_name,
2873  Symbol indices,
2874  int depth,
2875  double on_value = 1,
2876  double off_value = 0,
2878  static const char *One_hotDtypeValues[] = {
2879  "float16",
2880  "float32",
2881  "float64",
2882  "int32",
2883  "uint8"
2884  };
2885  return Operator("one_hot")
2886  .SetParam("depth", depth)
2887  .SetParam("on_value", on_value)
2888  .SetParam("off_value", off_value)
2889  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
2890  .SetInput("indices", indices)
2891  .CreateSymbol(symbol_name);
2892 }
2893 
2922 inline Symbol gather_nd(const std::string& symbol_name,
2923  Symbol data,
2924  Symbol indices) {
2925  return Operator("gather_nd")
2926  .SetInput("data", data)
2927  .SetInput("indices", indices)
2928  .CreateSymbol(symbol_name);
2929 }
2930 
2961 inline Symbol scatter_nd(const std::string& symbol_name,
2962  Symbol data,
2963  Symbol indices,
2964  Shape shape) {
2965  return Operator("scatter_nd")
2966  .SetParam("shape", shape)
2967  .SetInput("data", data)
2968  .SetInput("indices", indices)
2969  .CreateSymbol(symbol_name);
2970 }
2971 
2994 inline Symbol broadcast_equal(const std::string& symbol_name,
2995  Symbol lhs,
2996  Symbol rhs) {
2997  return Operator("broadcast_equal")
2998  .SetInput("lhs", lhs)
2999  .SetInput("rhs", rhs)
3000  .CreateSymbol(symbol_name);
3001 }
3002 
3025 inline Symbol broadcast_not_equal(const std::string& symbol_name,
3026  Symbol lhs,
3027  Symbol rhs) {
3028  return Operator("broadcast_not_equal")
3029  .SetInput("lhs", lhs)
3030  .SetInput("rhs", rhs)
3031  .CreateSymbol(symbol_name);
3032 }
3033 
3056 inline Symbol broadcast_greater(const std::string& symbol_name,
3057  Symbol lhs,
3058  Symbol rhs) {
3059  return Operator("broadcast_greater")
3060  .SetInput("lhs", lhs)
3061  .SetInput("rhs", rhs)
3062  .CreateSymbol(symbol_name);
3063 }
3064 
3087 inline Symbol broadcast_greater_equal(const std::string& symbol_name,
3088  Symbol lhs,
3089  Symbol rhs) {
3090  return Operator("broadcast_greater_equal")
3091  .SetInput("lhs", lhs)
3092  .SetInput("rhs", rhs)
3093  .CreateSymbol(symbol_name);
3094 }
3095 
3118 inline Symbol broadcast_lesser(const std::string& symbol_name,
3119  Symbol lhs,
3120  Symbol rhs) {
3121  return Operator("broadcast_lesser")
3122  .SetInput("lhs", lhs)
3123  .SetInput("rhs", rhs)
3124  .CreateSymbol(symbol_name);
3125 }
3126 
3149 inline Symbol broadcast_lesser_equal(const std::string& symbol_name,
3150  Symbol lhs,
3151  Symbol rhs) {
3152  return Operator("broadcast_lesser_equal")
3153  .SetInput("lhs", lhs)
3154  .SetInput("rhs", rhs)
3155  .CreateSymbol(symbol_name);
3156 }
3157 
3174 inline Symbol where(const std::string& symbol_name,
3175  Symbol condition,
3176  Symbol x,
3177  Symbol y) {
3178  return Operator("where")
3179  .SetInput("condition", condition)
3180  .SetInput("x", x)
3181  .SetInput("y", y)
3182  .CreateSymbol(symbol_name);
3183 }
3184 
3210 inline Symbol smooth_l1(const std::string& symbol_name,
3211  Symbol data,
3212  mx_float scalar) {
3213  return Operator("smooth_l1")
3214  .SetParam("scalar", scalar)
3215  .SetInput("data", data)
3216  .CreateSymbol(symbol_name);
3217 }
3218 
/*! \brief stype choices for cast_storage(). cast_storage() uses each
 *  enumerator's integer value as an index into its name table, so order and
 *  values must match it.
 */
enum class Cast_storageStype : int {
  kCsr = 0,
  kDefault = 1,
  kRow_sparse = 2,
};
3226 
3270 inline Symbol cast_storage(const std::string& symbol_name,
3271  Symbol data,
3272  Cast_storageStype stype) {
3273  static const char *Cast_storageStypeValues[] = {
3274  "csr",
3275  "default",
3276  "row_sparse"
3277  };
3278  return Operator("cast_storage")
3279  .SetParam("stype", Cast_storageStypeValues[int(stype)])
3280  .SetInput("data", data)
3281  .CreateSymbol(symbol_name);
3282 }
3283 
3304 inline Symbol sin(const std::string& symbol_name,
3305  Symbol data) {
3306  return Operator("sin")
3307  .SetInput("data", data)
3308  .CreateSymbol(symbol_name);
3309 }
3310 
3328 inline Symbol cos(const std::string& symbol_name,
3329  Symbol data) {
3330  return Operator("cos")
3331  .SetInput("data", data)
3332  .CreateSymbol(symbol_name);
3333 }
3334 
3355 inline Symbol tan(const std::string& symbol_name,
3356  Symbol data) {
3357  return Operator("tan")
3358  .SetInput("data", data)
3359  .CreateSymbol(symbol_name);
3360 }
3361 
3383 inline Symbol arcsin(const std::string& symbol_name,
3384  Symbol data) {
3385  return Operator("arcsin")
3386  .SetInput("data", data)
3387  .CreateSymbol(symbol_name);
3388 }
3389 
3408 inline Symbol arccos(const std::string& symbol_name,
3409  Symbol data) {
3410  return Operator("arccos")
3411  .SetInput("data", data)
3412  .CreateSymbol(symbol_name);
3413 }
3414 
3435 inline Symbol arctan(const std::string& symbol_name,
3436  Symbol data) {
3437  return Operator("arctan")
3438  .SetInput("data", data)
3439  .CreateSymbol(symbol_name);
3440 }
3441 
3460 inline Symbol degrees(const std::string& symbol_name,
3461  Symbol data) {
3462  return Operator("degrees")
3463  .SetInput("data", data)
3464  .CreateSymbol(symbol_name);
3465 }
3466 
3485 inline Symbol radians(const std::string& symbol_name,
3486  Symbol data) {
3487  return Operator("radians")
3488  .SetInput("data", data)
3489  .CreateSymbol(symbol_name);
3490 }
3491 
3510 inline Symbol sinh(const std::string& symbol_name,
3511  Symbol data) {
3512  return Operator("sinh")
3513  .SetInput("data", data)
3514  .CreateSymbol(symbol_name);
3515 }
3516 
3532 inline Symbol cosh(const std::string& symbol_name,
3533  Symbol data) {
3534  return Operator("cosh")
3535  .SetInput("data", data)
3536  .CreateSymbol(symbol_name);
3537 }
3538 
3557 inline Symbol tanh(const std::string& symbol_name,
3558  Symbol data) {
3559  return Operator("tanh")
3560  .SetInput("data", data)
3561  .CreateSymbol(symbol_name);
3562 }
3563 
3580 inline Symbol arcsinh(const std::string& symbol_name,
3581  Symbol data) {
3582  return Operator("arcsinh")
3583  .SetInput("data", data)
3584  .CreateSymbol(symbol_name);
3585 }
3586 
3600 inline Symbol arccosh(const std::string& symbol_name,
3601  Symbol data) {
3602  return Operator("arccosh")
3603  .SetInput("data", data)
3604  .CreateSymbol(symbol_name);
3605 }
3606 
3623 inline Symbol arctanh(const std::string& symbol_name,
3624  Symbol data) {
3625  return Operator("arctanh")
3626  .SetInput("data", data)
3627  .CreateSymbol(symbol_name);
3628 }
3629 
3645 inline Symbol Custom(const std::string& symbol_name,
3646  const std::vector<Symbol>& data,
3647  const std::string& op_type) {
3648  return Operator("Custom")
3649 (data)
3650  .CreateSymbol(symbol_name);
3651 }
3652 
3681 inline Symbol SwapAxis(const std::string& symbol_name,
3682  Symbol data,
3683  uint32_t dim1 = 0,
3684  uint32_t dim2 = 0) {
3685  return Operator("SwapAxis")
3686  .SetParam("dim1", dim1)
3687  .SetParam("dim2", dim2)
3688  .SetInput("data", data)
3689  .CreateSymbol(symbol_name);
3690 }
3691 
/*! \brief act_type choices for LeakyReLU(). LeakyReLU() uses each
 *  enumerator's integer value as an index into its name table, so order and
 *  values must match it.
 */
enum class LeakyReLUActType : int {
  kElu = 0,
  kLeaky = 1,
  kPrelu = 2,
  kRrelu = 3,
};
3700 
3727 inline Symbol LeakyReLU(const std::string& symbol_name,
3728  Symbol data,
3730  mx_float slope = 0.25,
3731  mx_float lower_bound = 0.125,
3732  mx_float upper_bound = 0.334) {
3733  static const char *LeakyReLUActTypeValues[] = {
3734  "elu",
3735  "leaky",
3736  "prelu",
3737  "rrelu"
3738  };
3739  return Operator("LeakyReLU")
3740  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
3741  .SetParam("slope", slope)
3742  .SetParam("lower_bound", lower_bound)
3743  .SetParam("upper_bound", upper_bound)
3744  .SetInput("data", data)
3745  .CreateSymbol(symbol_name);
3746 }
3747 
3803 inline Symbol BatchNorm_v1(const std::string& symbol_name,
3804  Symbol data,
3805  Symbol gamma,
3806  Symbol beta,
3807  mx_float eps = 0.001,
3808  mx_float momentum = 0.9,
3809  bool fix_gamma = 1,
3810  bool use_global_stats = 0,
3811  bool output_mean_var = 0) {
3812  return Operator("BatchNorm_v1")
3813  .SetParam("eps", eps)
3814  .SetParam("momentum", momentum)
3815  .SetParam("fix_gamma", fix_gamma)
3816  .SetParam("use_global_stats", use_global_stats)
3817  .SetParam("output_mean_var", output_mean_var)
3818  .SetInput("data", data)
3819  .SetInput("gamma", gamma)
3820  .SetInput("beta", beta)
3821  .CreateSymbol(symbol_name);
3822 }
3823 
3865 inline Symbol Concat(const std::string& symbol_name,
3866  const std::vector<Symbol>& data,
3867  int num_args,
3868  int dim = 1) {
3869  return Operator("Concat")
3870  .SetParam("num_args", num_args)
3871  .SetParam("dim", dim)
3872 (data)
3873  .CreateSymbol(symbol_name);
3874 }
3875 
3903 inline Symbol sgd_update(const std::string& symbol_name,
3904  Symbol weight,
3905  Symbol grad,
3906  mx_float lr,
3907  mx_float wd = 0,
3908  mx_float rescale_grad = 1,
3909  mx_float clip_gradient = -1) {
3910  return Operator("sgd_update")
3911  .SetParam("lr", lr)
3912  .SetParam("wd", wd)
3913  .SetParam("rescale_grad", rescale_grad)
3914  .SetParam("clip_gradient", clip_gradient)
3915  .SetInput("weight", weight)
3916  .SetInput("grad", grad)
3917  .CreateSymbol(symbol_name);
3918 }
3919 
3962 inline Symbol sgd_mom_update(const std::string& symbol_name,
3963  Symbol weight,
3964  Symbol grad,
3965  Symbol mom,
3966  mx_float lr,
3967  mx_float momentum = 0,
3968  mx_float wd = 0,
3969  mx_float rescale_grad = 1,
3970  mx_float clip_gradient = -1) {
3971  return Operator("sgd_mom_update")
3972  .SetParam("lr", lr)
3973  .SetParam("momentum", momentum)
3974  .SetParam("wd", wd)
3975  .SetParam("rescale_grad", rescale_grad)
3976  .SetParam("clip_gradient", clip_gradient)
3977  .SetInput("weight", weight)
3978  .SetInput("grad", grad)
3979  .SetInput("mom", mom)
3980  .CreateSymbol(symbol_name);
3981 }
3982 
3997 inline Symbol mp_sgd_update(const std::string& symbol_name,
3998  Symbol weight,
3999  Symbol grad,
4000  Symbol weight32,
4001  mx_float lr,
4002  mx_float wd = 0,
4003  mx_float rescale_grad = 1,
4004  mx_float clip_gradient = -1) {
4005  return Operator("mp_sgd_update")
4006  .SetParam("lr", lr)
4007  .SetParam("wd", wd)
4008  .SetParam("rescale_grad", rescale_grad)
4009  .SetParam("clip_gradient", clip_gradient)
4010  .SetInput("weight", weight)
4011  .SetInput("grad", grad)
4012  .SetInput("weight32", weight32)
4013  .CreateSymbol(symbol_name);
4014 }
4015 
4032 inline Symbol mp_sgd_mom_update(const std::string& symbol_name,
4033  Symbol weight,
4034  Symbol grad,
4035  Symbol mom,
4036  Symbol weight32,
4037  mx_float lr,
4038  mx_float momentum = 0,
4039  mx_float wd = 0,
4040  mx_float rescale_grad = 1,
4041  mx_float clip_gradient = -1) {
4042  return Operator("mp_sgd_mom_update")
4043  .SetParam("lr", lr)
4044  .SetParam("momentum", momentum)
4045  .SetParam("wd", wd)
4046  .SetParam("rescale_grad", rescale_grad)
4047  .SetParam("clip_gradient", clip_gradient)
4048  .SetInput("weight", weight)
4049  .SetInput("grad", grad)
4050  .SetInput("mom", mom)
4051  .SetInput("weight32", weight32)
4052  .CreateSymbol(symbol_name);
4053 }
4054 
4102 inline Symbol adam_update(const std::string& symbol_name,
4103  Symbol weight,
4104  Symbol grad,
4105  Symbol mean,
4106  Symbol var,
4107  mx_float lr,
4108  mx_float beta1 = 0.9,
4109  mx_float beta2 = 0.999,
4110  mx_float epsilon = 1e-08,
4111  mx_float wd = 0,
4112  mx_float rescale_grad = 1,
4113  mx_float clip_gradient = -1) {
4114  return Operator("adam_update")
4115  .SetParam("lr", lr)
4116  .SetParam("beta1", beta1)
4117  .SetParam("beta2", beta2)
4118  .SetParam("epsilon", epsilon)
4119  .SetParam("wd", wd)
4120  .SetParam("rescale_grad", rescale_grad)
4121  .SetParam("clip_gradient", clip_gradient)
4122  .SetInput("weight", weight)
4123  .SetInput("grad", grad)
4124  .SetInput("mean", mean)
4125  .SetInput("var", var)
4126  .CreateSymbol(symbol_name);
4127 }
4128 
4182 inline Symbol rmsprop_update(const std::string& symbol_name,
4183  Symbol weight,
4184  Symbol grad,
4185  Symbol n,
4186  mx_float lr,
4187  mx_float gamma1 = 0.95,
4188  mx_float epsilon = 1e-08,
4189  mx_float wd = 0,
4190  mx_float rescale_grad = 1,
4191  mx_float clip_gradient = -1,
4192  mx_float clip_weights = -1) {
4193  return Operator("rmsprop_update")
4194  .SetParam("lr", lr)
4195  .SetParam("gamma1", gamma1)
4196  .SetParam("epsilon", epsilon)
4197  .SetParam("wd", wd)
4198  .SetParam("rescale_grad", rescale_grad)
4199  .SetParam("clip_gradient", clip_gradient)
4200  .SetParam("clip_weights", clip_weights)
4201  .SetInput("weight", weight)
4202  .SetInput("grad", grad)
4203  .SetInput("n", n)
4204  .CreateSymbol(symbol_name);
4205 }
4206 
4252 inline Symbol rmspropalex_update(const std::string& symbol_name,
4253  Symbol weight,
4254  Symbol grad,
4255  Symbol n,
4256  Symbol g,
4257  Symbol delta,
4258  mx_float lr,
4259  mx_float gamma1 = 0.95,
4260  mx_float gamma2 = 0.9,
4261  mx_float epsilon = 1e-08,
4262  mx_float wd = 0,
4263  mx_float rescale_grad = 1,
4264  mx_float clip_gradient = -1,
4265  mx_float clip_weights = -1) {
4266  return Operator("rmspropalex_update")
4267  .SetParam("lr", lr)
4268  .SetParam("gamma1", gamma1)
4269  .SetParam("gamma2", gamma2)
4270  .SetParam("epsilon", epsilon)
4271  .SetParam("wd", wd)
4272  .SetParam("rescale_grad", rescale_grad)
4273  .SetParam("clip_gradient", clip_gradient)
4274  .SetParam("clip_weights", clip_weights)
4275  .SetInput("weight", weight)
4276  .SetInput("grad", grad)
4277  .SetInput("n", n)
4278  .SetInput("g", g)
4279  .SetInput("delta", delta)
4280  .CreateSymbol(symbol_name);
4281 }
4282 
4322 inline Symbol ftrl_update(const std::string& symbol_name,
4323  Symbol weight,
4324  Symbol grad,
4325  Symbol z,
4326  Symbol n,
4327  mx_float lr,
4328  mx_float lamda1 = 0.01,
4329  mx_float beta = 1,
4330  mx_float wd = 0,
4331  mx_float rescale_grad = 1,
4332  mx_float clip_gradient = -1) {
4333  return Operator("ftrl_update")
4334  .SetParam("lr", lr)
4335  .SetParam("lamda1", lamda1)
4336  .SetParam("beta", beta)
4337  .SetParam("wd", wd)
4338  .SetParam("rescale_grad", rescale_grad)
4339  .SetParam("clip_gradient", clip_gradient)
4340  .SetInput("weight", weight)
4341  .SetInput("grad", grad)
4342  .SetInput("z", z)
4343  .SetInput("n", n)
4344  .CreateSymbol(symbol_name);
4345 }
4346 
/*!
 * \brief Padding modes accepted by Pad().
 *
 * Pad() uses each enumerator's value as an index into its string table.
 * The values must therefore start at 0 and stay contiguous.
 */
enum class PadMode : int {
  kConstant,  // 0 -> "constant"
  kEdge,      // 1 -> "edge"
  kReflect    // 2 -> "reflect"
};
4355 
4452 inline Symbol Pad(const std::string& symbol_name,
4453  Symbol data,
4454  PadMode mode,
4455  Shape pad_width,
4456  double constant_value = 0) {
4457  static const char *PadModeValues[] = {
4458  "constant",
4459  "edge",
4460  "reflect"
4461  };
4462  return Operator("Pad")
4463  .SetParam("mode", PadModeValues[int(mode)])
4464  .SetParam("pad_width", pad_width)
4465  .SetParam("constant_value", constant_value)
4466  .SetInput("data", data)
4467  .CreateSymbol(symbol_name);
4468 }
4469 
4479 inline Symbol IdentityAttachKLSparseReg(const std::string& symbol_name,
4480  Symbol data,
4481  mx_float sparseness_target = 0.1,
4482  mx_float penalty = 0.001,
4483  mx_float momentum = 0.9) {
4484  return Operator("IdentityAttachKLSparseReg")
4485  .SetParam("sparseness_target", sparseness_target)
4486  .SetParam("penalty", penalty)
4487  .SetParam("momentum", momentum)
4488  .SetInput("data", data)
4489  .CreateSymbol(symbol_name);
4490 }
4491 
4563 inline Symbol SliceChannel(const std::string& symbol_name,
4564  Symbol data,
4565  int num_outputs,
4566  int axis = 1,
4567  bool squeeze_axis = 0) {
4568  return Operator("SliceChannel")
4569  .SetParam("num_outputs", num_outputs)
4570  .SetParam("axis", axis)
4571  .SetParam("squeeze_axis", squeeze_axis)
4572  .SetInput("data", data)
4573  .CreateSymbol(symbol_name);
4574 }
4575 
4613 inline Symbol softmax_cross_entropy(const std::string& symbol_name,
4614  Symbol data,
4615  Symbol label) {
4616  return Operator("softmax_cross_entropy")
4617  .SetInput("data", data)
4618  .SetInput("label", label)
4619  .CreateSymbol(symbol_name);
4620 }
4621 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with UpSamplingSampleTypeValues
// ("bilinear", "nearest") in UpSampling, so this is presumably the
// UpSamplingSampleType enum. UpSampling uses each enumerator's value as an
// index into that string table.
4625  kBilinear = 0,
4626  kNearest = 1
4627 };
4628 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with UpSamplingMultiInputModeValues
// ("concat", "sum") in UpSampling, so this is presumably the
// UpSamplingMultiInputMode enum. Each value is used there as a table index.
4633  kConcat = 0,
4634  kSum = 1
4635 };
4636 
/*!
 * \brief Creates an `UpSampling` operator symbol whose inputs are all the
 *        symbols in `data`.
 *
 * Enum arguments are converted to strings by indexing the local lookup
 * tables with the enumerator value.
 *
 * NOTE(review): one parameter line is missing from this listing, between
 * `num_filter` and `workspace`. The body reads `multi_input_mode`, so that
 * line presumably declares an UpSamplingMultiInputMode parameter. Confirm
 * against the generated op.h.
 * \return the created symbol, named `symbol_name`
 */
4652 inline Symbol UpSampling(const std::string& symbol_name,
4653  const std::vector<Symbol>& data,
4654  uint32_t scale,
4655  UpSamplingSampleType sample_type,
4656  int num_args,
4657  uint32_t num_filter = 0,
4659  uint64_t workspace = 512) {
4660  static const char *UpSamplingSampleTypeValues[] = {
4661  "bilinear",
4662  "nearest"
4663  };
4664  static const char *UpSamplingMultiInputModeValues[] = {
4665  "concat",
4666  "sum"
4667  };
4668  return Operator("UpSampling")
4669  .SetParam("scale", scale)
4670  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
4671  .SetParam("num_args", num_args)
4672  .SetParam("num_filter", num_filter)
4673  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
4674  .SetParam("workspace", workspace)
// Calls the Operator directly with the whole `data` list (presumably
// Operator::operator()(const std::vector<Symbol>&); confirm in operator.h).
4675 (data)
4676  .CreateSymbol(symbol_name);
4677 }
4678 
4742 inline Symbol BatchNorm(const std::string& symbol_name,
4743  Symbol data,
4744  Symbol gamma,
4745  Symbol beta,
4746  Symbol moving_mean,
4747  Symbol moving_var,
4748  double eps = 0.001,
4749  mx_float momentum = 0.9,
4750  bool fix_gamma = 1,
4751  bool use_global_stats = 0,
4752  bool output_mean_var = 0,
4753  int axis = 1,
4754  bool cudnn_off = 0) {
4755  return Operator("BatchNorm")
4756  .SetParam("eps", eps)
4757  .SetParam("momentum", momentum)
4758  .SetParam("fix_gamma", fix_gamma)
4759  .SetParam("use_global_stats", use_global_stats)
4760  .SetParam("output_mean_var", output_mean_var)
4761  .SetParam("axis", axis)
4762  .SetParam("cudnn_off", cudnn_off)
4763  .SetInput("data", data)
4764  .SetInput("gamma", gamma)
4765  .SetInput("beta", beta)
4766  .SetInput("moving_mean", moving_mean)
4767  .SetInput("moving_var", moving_var)
4768  .CreateSymbol(symbol_name);
4769 }
4770 
4821 inline Symbol InstanceNorm(const std::string& symbol_name,
4822  Symbol data,
4823  Symbol gamma,
4824  Symbol beta,
4825  mx_float eps = 0.001) {
4826  return Operator("InstanceNorm")
4827  .SetParam("eps", eps)
4828  .SetInput("data", data)
4829  .SetInput("gamma", gamma)
4830  .SetInput("beta", beta)
4831  .CreateSymbol(symbol_name);
4832 }
4833 
/*!
 * \brief Cell types accepted by RNN().
 *
 * RNN() uses each enumerator's value as an index into its string table.
 * The values must therefore start at 0 and stay contiguous.
 */
enum class RNNMode : int {
  kGru,        // 0 -> "gru"
  kLstm,       // 1 -> "lstm"
  kRnn_relu,   // 2 -> "rnn_relu"
  kRnn_tanh    // 3 -> "rnn_tanh"
};
4842 
4858 inline Symbol RNN(const std::string& symbol_name,
4859  Symbol data,
4860  Symbol parameters,
4861  Symbol state,
4862  Symbol state_cell,
4863  uint32_t state_size,
4864  uint32_t num_layers,
4865  RNNMode mode,
4866  bool bidirectional = 0,
4867  mx_float p = 0,
4868  bool state_outputs = 0) {
4869  static const char *RNNModeValues[] = {
4870  "gru",
4871  "lstm",
4872  "rnn_relu",
4873  "rnn_tanh"
4874  };
4875  return Operator("RNN")
4876  .SetParam("state_size", state_size)
4877  .SetParam("num_layers", num_layers)
4878  .SetParam("mode", RNNModeValues[int(mode)])
4879  .SetParam("bidirectional", bidirectional)
4880  .SetParam("p", p)
4881  .SetParam("state_outputs", state_outputs)
4882  .SetInput("data", data)
4883  .SetInput("parameters", parameters)
4884  .SetInput("state", state)
4885  .SetInput("state_cell", state_cell)
4886  .CreateSymbol(symbol_name);
4887 }
4888 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with Convolution_v1CudnnTuneValues
// ("None", "fastest", "limited_workspace", "off") in Convolution_v1, so this
// is presumably Convolution_v1CudnnTune. Each value is used there as a
// table index.
4899  kNone = 0,
4900  kFastest = 1,
4901  kLimited_workspace = 2,
4902  kOff = 3
4903 };
4904 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with Convolution_v1LayoutValues
// ("None", "NCDHW", "NCHW", "NDHWC", "NHWC") in Convolution_v1, so this is
// presumably Convolution_v1Layout. Each value is used there as a table index.
4909  kNone = 0,
4910  kNCDHW = 1,
4911  kNCHW = 2,
4912  kNDHWC = 3,
4913  kNHWC = 4
4914 };
4915 
/*!
 * \brief Creates a `Convolution_v1` operator symbol.
 *
 * Enum-typed arguments are converted to strings by indexing the local
 * lookup tables with the enumerator value. `data`, `weight` and `bias` are
 * bound as inputs, in that order.
 *
 * NOTE(review): two lines are missing from this listing. One sits between
 * `no_bias` and `cudnn_off`. The other follows `cudnn_off` and closes the
 * parameter list. The body reads `cudnn_tune` and `layout`, so these are
 * presumably Convolution_v1CudnnTune / Convolution_v1Layout parameters.
 * Confirm against the generated op.h.
 * \return the created symbol, named `symbol_name`
 */
4944 inline Symbol Convolution_v1(const std::string& symbol_name,
4945  Symbol data,
4946  Symbol weight,
4947  Symbol bias,
4948  Shape kernel,
4949  uint32_t num_filter,
4950  Shape stride = Shape(),
4951  Shape dilate = Shape(),
4952  Shape pad = Shape(),
4953  uint32_t num_group = 1,
4954  uint64_t workspace = 1024,
4955  bool no_bias = 0,
4957  bool cudnn_off = 0,
4959  static const char *Convolution_v1CudnnTuneValues[] = {
4960  "None",
4961  "fastest",
4962  "limited_workspace",
4963  "off"
4964  };
4965  static const char *Convolution_v1LayoutValues[] = {
4966  "None",
4967  "NCDHW",
4968  "NCHW",
4969  "NDHWC",
4970  "NHWC"
4971  };
4972  return Operator("Convolution_v1")
4973  .SetParam("kernel", kernel)
4974  .SetParam("num_filter", num_filter)
4975  .SetParam("stride", stride)
4976  .SetParam("dilate", dilate)
4977  .SetParam("pad", pad)
4978  .SetParam("num_group", num_group)
4979  .SetParam("workspace", workspace)
4980  .SetParam("no_bias", no_bias)
4981  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
4982  .SetParam("cudnn_off", cudnn_off)
4983  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
4984  .SetInput("data", data)
4985  .SetInput("weight", weight)
4986  .SetInput("bias", bias)
4987  .CreateSymbol(symbol_name);
4988 }
4989 
5010 inline Symbol Crop(const std::string& symbol_name,
5011  const std::vector<Symbol>& data,
5012  int num_args,
5013  Shape offset = Shape(0,0),
5014  Shape h_w = Shape(0,0),
5015  bool center_crop = 0) {
5016  return Operator("Crop")
5017  .SetParam("num_args", num_args)
5018  .SetParam("offset", offset)
5019  .SetParam("h_w", h_w)
5020  .SetParam("center_crop", center_crop)
5021 (data)
5022  .CreateSymbol(symbol_name);
5023 }
5024 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. Given SpatialTransformerTransformTypeValues ("affine") in
// SpatialTransformer, this is presumably SpatialTransformerTransformType.
// Its single value indexes that table.
5028  kAffine = 0
5029 };
5030 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. Given SpatialTransformerSamplerTypeValues ("bilinear") in
// SpatialTransformer, this is presumably SpatialTransformerSamplerType.
// Its single value indexes that table.
5034  kBilinear = 0
5035 };
5036 
5047 inline Symbol SpatialTransformer(const std::string& symbol_name,
5048  Symbol data,
5049  Symbol loc,
5050  SpatialTransformerTransformType transform_type,
5051  SpatialTransformerSamplerType sampler_type,
5052  Shape target_shape = Shape(0,0)) {
5053  static const char *SpatialTransformerTransformTypeValues[] = {
5054  "affine"
5055  };
5056  static const char *SpatialTransformerSamplerTypeValues[] = {
5057  "bilinear"
5058  };
5059  return Operator("SpatialTransformer")
5060  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
5061  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
5062  .SetParam("target_shape", target_shape)
5063  .SetInput("data", data)
5064  .SetInput("loc", loc)
5065  .CreateSymbol(symbol_name);
5066 }
5067 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with DeconvolutionCudnnTuneValues
// ("None", "fastest", "limited_workspace", "off") in Deconvolution, so this
// is presumably DeconvolutionCudnnTune. Each value indexes that table.
5071  kNone = 0,
5072  kFastest = 1,
5073  kLimited_workspace = 2,
5074  kOff = 3
5075 };
5076 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with DeconvolutionLayoutValues
// ("None", "NCDHW", "NCHW", "NCW", "NDHWC", "NHWC") in Deconvolution, so this
// is presumably DeconvolutionLayout. Each value indexes that table.
5080  kNone = 0,
5081  kNCDHW = 1,
5082  kNCHW = 2,
5083  kNCW = 3,
5084  kNDHWC = 4,
5085  kNHWC = 5
5086 };
5087 
/*!
 * \brief Creates a `Deconvolution` operator symbol.
 *
 * Enum-typed arguments are converted to strings by indexing the local
 * lookup tables with the enumerator value. `data`, `weight` and `bias` are
 * bound as inputs, in that order. Note that `no_bias` defaults to 1 here.
 *
 * NOTE(review): two lines are missing from this listing. One sits between
 * `no_bias` and `cudnn_off`. The other follows `cudnn_off` and closes the
 * parameter list. The body reads `cudnn_tune` and `layout`, so these are
 * presumably DeconvolutionCudnnTune / DeconvolutionLayout parameters.
 * Confirm against the generated op.h.
 * \return the created symbol, named `symbol_name`
 */
5114 inline Symbol Deconvolution(const std::string& symbol_name,
5115  Symbol data,
5116  Symbol weight,
5117  Symbol bias,
5118  Shape kernel,
5119  uint32_t num_filter,
5120  Shape stride = Shape(),
5121  Shape dilate = Shape(),
5122  Shape pad = Shape(),
5123  Shape adj = Shape(),
5124  Shape target_shape = Shape(),
5125  uint32_t num_group = 1,
5126  uint64_t workspace = 512,
5127  bool no_bias = 1,
5129  bool cudnn_off = 0,
5131  static const char *DeconvolutionCudnnTuneValues[] = {
5132  "None",
5133  "fastest",
5134  "limited_workspace",
5135  "off"
5136  };
5137  static const char *DeconvolutionLayoutValues[] = {
5138  "None",
5139  "NCDHW",
5140  "NCHW",
5141  "NCW",
5142  "NDHWC",
5143  "NHWC"
5144  };
5145  return Operator("Deconvolution")
5146  .SetParam("kernel", kernel)
5147  .SetParam("num_filter", num_filter)
5148  .SetParam("stride", stride)
5149  .SetParam("dilate", dilate)
5150  .SetParam("pad", pad)
5151  .SetParam("adj", adj)
5152  .SetParam("target_shape", target_shape)
5153  .SetParam("num_group", num_group)
5154  .SetParam("workspace", workspace)
5155  .SetParam("no_bias", no_bias)
5156  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
5157  .SetParam("cudnn_off", cudnn_off)
5158  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
5159  .SetInput("data", data)
5160  .SetInput("weight", weight)
5161  .SetInput("bias", bias)
5162  .CreateSymbol(symbol_name);
5163 }
5164 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with SoftmaxOutputNormalizationValues
// ("batch", "null", "valid") in SoftmaxOutput, so this is presumably
// SoftmaxOutputNormalization. Each value indexes that table.
5168  kBatch = 0,
5169  kNull = 1,
5170  kValid = 2
5171 };
5172 
/*!
 * \brief Creates a `SoftmaxOutput` operator symbol with inputs `data` and
 *        `label`.
 *
 * `normalization` is converted to a string by indexing the local table
 * ("batch", "null", "valid") with the enumerator value.
 *
 * NOTE(review): one parameter line is missing from this listing, between
 * `preserve_shape` and `out_grad`. The body reads `normalization`, so it is
 * presumably a SoftmaxOutputNormalization parameter. Confirm against op.h.
 * \return the created symbol, named `symbol_name`
 */
5268 inline Symbol SoftmaxOutput(const std::string& symbol_name,
5269  Symbol data,
5270  Symbol label,
5271  mx_float grad_scale = 1,
5272  mx_float ignore_label = -1,
5273  bool multi_output = 0,
5274  bool use_ignore = 0,
5275  bool preserve_shape = 0,
5277  bool out_grad = 0,
5278  mx_float smooth_alpha = 0) {
5279  static const char *SoftmaxOutputNormalizationValues[] = {
5280  "batch",
5281  "null",
5282  "valid"
5283  };
5284  return Operator("SoftmaxOutput")
5285  .SetParam("grad_scale", grad_scale)
5286  .SetParam("ignore_label", ignore_label)
5287  .SetParam("multi_output", multi_output)
5288  .SetParam("use_ignore", use_ignore)
5289  .SetParam("preserve_shape", preserve_shape)
5290  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
5291  .SetParam("out_grad", out_grad)
5292  .SetParam("smooth_alpha", smooth_alpha)
5293  .SetInput("data", data)
5294  .SetInput("label", label)
5295  .CreateSymbol(symbol_name);
5296 }
5297 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with SoftmaxNormalizationValues
// ("batch", "null", "valid") in Softmax, so this is presumably
// SoftmaxNormalization. Each value indexes that table.
5301  kBatch = 0,
5302  kNull = 1,
5303  kValid = 2
5304 };
5305 
/*!
 * \brief Creates a `Softmax` operator symbol over `data`.
 *
 * Unlike SoftmaxOutput, no `label` input is bound. `normalization` is
 * converted to a string by indexing the local table with the enumerator
 * value.
 *
 * NOTE(review): one parameter line is missing from this listing, between
 * `preserve_shape` and `out_grad`. The body reads `normalization`, so it is
 * presumably a SoftmaxNormalization parameter. Confirm against op.h.
 * \return the created symbol, named `symbol_name`
 */
5333 inline Symbol Softmax(const std::string& symbol_name,
5334  Symbol data,
5335  mx_float grad_scale = 1,
5336  mx_float ignore_label = -1,
5337  bool multi_output = 0,
5338  bool use_ignore = 0,
5339  bool preserve_shape = 0,
5341  bool out_grad = 0,
5342  mx_float smooth_alpha = 0) {
5343  static const char *SoftmaxNormalizationValues[] = {
5344  "batch",
5345  "null",
5346  "valid"
5347  };
5348  return Operator("Softmax")
5349  .SetParam("grad_scale", grad_scale)
5350  .SetParam("ignore_label", ignore_label)
5351  .SetParam("multi_output", multi_output)
5352  .SetParam("use_ignore", use_ignore)
5353  .SetParam("preserve_shape", preserve_shape)
5354  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
5355  .SetParam("out_grad", out_grad)
5356  .SetParam("smooth_alpha", smooth_alpha)
5357  .SetInput("data", data)
5358  .CreateSymbol(symbol_name);
5359 }
5360 
5436 inline Symbol SequenceReverse(const std::string& symbol_name,
5437  Symbol data,
5438  Symbol sequence_length,
5439  bool use_sequence_length = 0) {
5440  return Operator("SequenceReverse")
5441  .SetParam("use_sequence_length", use_sequence_length)
5442  .SetInput("data", data)
5443  .SetInput("sequence_length", sequence_length)
5444  .CreateSymbol(symbol_name);
5445 }
5446 
5501 inline Symbol SequenceLast(const std::string& symbol_name,
5502  Symbol data,
5503  Symbol sequence_length,
5504  bool use_sequence_length = 0) {
5505  return Operator("SequenceLast")
5506  .SetParam("use_sequence_length", use_sequence_length)
5507  .SetInput("data", data)
5508  .SetInput("sequence_length", sequence_length)
5509  .CreateSymbol(symbol_name);
5510 }
5511 
5560 inline Symbol Correlation(const std::string& symbol_name,
5561  Symbol data1,
5562  Symbol data2,
5563  uint32_t kernel_size = 1,
5564  uint32_t max_displacement = 1,
5565  uint32_t stride1 = 1,
5566  uint32_t stride2 = 1,
5567  uint32_t pad_size = 0,
5568  bool is_multiply = 1) {
5569  return Operator("Correlation")
5570  .SetParam("kernel_size", kernel_size)
5571  .SetParam("max_displacement", max_displacement)
5572  .SetParam("stride1", stride1)
5573  .SetParam("stride2", stride2)
5574  .SetParam("pad_size", pad_size)
5575  .SetParam("is_multiply", is_multiply)
5576  .SetInput("data1", data1)
5577  .SetInput("data2", data2)
5578  .CreateSymbol(symbol_name);
5579 }
5580 
5596 inline Symbol SVMOutput(const std::string& symbol_name,
5597  Symbol data,
5598  Symbol label,
5599  mx_float margin = 1,
5600  mx_float regularization_coefficient = 1,
5601  bool use_linear = 0) {
5602  return Operator("SVMOutput")
5603  .SetParam("margin", margin)
5604  .SetParam("regularization_coefficient", regularization_coefficient)
5605  .SetParam("use_linear", use_linear)
5606  .SetInput("data", data)
5607  .SetInput("label", label)
5608  .CreateSymbol(symbol_name);
5609 }
5610 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with L2NormalizationModeValues
// ("channel", "instance", "spatial") in L2Normalization, so this is
// presumably L2NormalizationMode. Each value indexes that table.
5614  kChannel = 0,
5615  kInstance = 1,
5616  kSpatial = 2
5617 };
5618 
/*!
 * \brief Creates an `L2Normalization` operator symbol over `data`.
 *
 * `mode` is converted to a string by indexing the local table
 * ("channel", "instance", "spatial") with the enumerator value.
 *
 * NOTE(review): the parameter line after `eps`, which also opens the body,
 * is missing from this listing. The body reads `mode`, so it is presumably
 * an L2NormalizationMode parameter. Confirm against op.h.
 * \return the created symbol, named `symbol_name`
 */
5681 inline Symbol L2Normalization(const std::string& symbol_name,
5682  Symbol data,
5683  mx_float eps = 1e-10,
5685  static const char *L2NormalizationModeValues[] = {
5686  "channel",
5687  "instance",
5688  "spatial"
5689  };
5690  return Operator("L2Normalization")
5691  .SetParam("eps", eps)
5692  .SetParam("mode", L2NormalizationModeValues[int(mode)])
5693  .SetInput("data", data)
5694  .CreateSymbol(symbol_name);
5695 }
5696 
5724 inline Symbol LRN(const std::string& symbol_name,
5725  Symbol data,
5726  uint32_t nsize,
5727  mx_float alpha = 0.0001,
5728  mx_float beta = 0.75,
5729  mx_float knorm = 2) {
5730  return Operator("LRN")
5731  .SetParam("nsize", nsize)
5732  .SetParam("alpha", alpha)
5733  .SetParam("beta", beta)
5734  .SetParam("knorm", knorm)
5735  .SetInput("data", data)
5736  .CreateSymbol(symbol_name);
5737 }
5738 
5772 inline Symbol FullyConnected(const std::string& symbol_name,
5773  Symbol data,
5774  Symbol weight,
5775  Symbol bias,
5776  int num_hidden,
5777  bool no_bias = 0,
5778  bool flatten = 1) {
5779  return Operator("FullyConnected")
5780  .SetParam("num_hidden", num_hidden)
5781  .SetParam("no_bias", no_bias)
5782  .SetParam("flatten", flatten)
5783  .SetInput("data", data)
5784  .SetInput("weight", weight)
5785  .SetInput("bias", bias)
5786  .CreateSymbol(symbol_name);
5787 }
5788 
5866 inline Symbol SequenceMask(const std::string& symbol_name,
5867  Symbol data,
5868  Symbol sequence_length,
5869  bool use_sequence_length = 0,
5870  mx_float value = 0) {
5871  return Operator("SequenceMask")
5872  .SetParam("use_sequence_length", use_sequence_length)
5873  .SetParam("value", value)
5874  .SetInput("data", data)
5875  .SetInput("sequence_length", sequence_length)
5876  .CreateSymbol(symbol_name);
5877 }
5878 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with GridGeneratorTransformTypeValues
// ("affine", "warp") in GridGenerator, so this is presumably
// GridGeneratorTransformType. Each value indexes that table.
5883  kAffine = 0,
5884  kWarp = 1
5885 };
5886 
5897 inline Symbol GridGenerator(const std::string& symbol_name,
5898  Symbol data,
5899  GridGeneratorTransformType transform_type,
5900  Shape target_shape = Shape(0,0)) {
5901  static const char *GridGeneratorTransformTypeValues[] = {
5902  "affine",
5903  "warp"
5904  };
5905  return Operator("GridGenerator")
5906  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
5907  .SetParam("target_shape", target_shape)
5908  .SetInput("data", data)
5909  .CreateSymbol(symbol_name);
5910 }
5911 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with Pooling_v1PoolTypeValues
// ("avg", "max", "sum") in Pooling_v1, so this is presumably
// Pooling_v1PoolType. Each value indexes that table.
5915  kAvg = 0,
5916  kMax = 1,
5917  kSum = 2
5918 };
5919 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with Pooling_v1PoolingConventionValues
// ("full", "valid") in Pooling_v1, so this is presumably
// Pooling_v1PoolingConvention. Each value indexes that table.
5923  kFull = 0,
5924  kValid = 1
5925 };
5926 
/*!
 * \brief Creates a `Pooling_v1` operator symbol over `data`.
 *
 * `pool_type` and `pooling_convention` are converted to strings by indexing
 * the local lookup tables with the enumerator value.
 *
 * NOTE(review): one parameter line is missing from this listing, between
 * `global_pool` and `stride`. The body reads `pooling_convention`, so it is
 * presumably a Pooling_v1PoolingConvention parameter. Confirm against op.h.
 * \return the created symbol, named `symbol_name`
 */
5978 inline Symbol Pooling_v1(const std::string& symbol_name,
5979  Symbol data,
5980  Shape kernel,
5981  Pooling_v1PoolType pool_type,
5982  bool global_pool = 0,
5984  Shape stride = Shape(),
5985  Shape pad = Shape()) {
5986  static const char *Pooling_v1PoolTypeValues[] = {
5987  "avg",
5988  "max",
5989  "sum"
5990  };
5991  static const char *Pooling_v1PoolingConventionValues[] = {
5992  "full",
5993  "valid"
5994  };
5995  return Operator("Pooling_v1")
5996  .SetParam("kernel", kernel)
5997  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
5998  .SetParam("global_pool", global_pool)
5999  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
6000  .SetParam("stride", stride)
6001  .SetParam("pad", pad)
6002  .SetInput("data", data)
6003  .CreateSymbol(symbol_name);
6004 }
6005 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with ConvolutionCudnnTuneValues
// ("None", "fastest", "limited_workspace", "off") in Convolution, so this
// is presumably ConvolutionCudnnTune. Each value indexes that table.
6009  kNone = 0,
6010  kFastest = 1,
6011  kLimited_workspace = 2,
6012  kOff = 3
6013 };
6014 
/*!
 * \brief Data layouts accepted by Convolution().
 *
 * Convolution() uses each enumerator's value as an index into its string
 * table. The values must therefore start at 0 and stay contiguous.
 */
enum class ConvolutionLayout : int {
  kNone,    // 0 -> "None"
  kNCDHW,   // 1 -> "NCDHW"
  kNCHW,    // 2 -> "NCHW"
  kNCW,     // 3 -> "NCW"
  kNDHWC,   // 4 -> "NDHWC"
  kNHWC     // 5 -> "NHWC"
};
6026 
/*!
 * \brief Creates a `Convolution` operator symbol.
 *
 * Enum-typed arguments are converted to strings by indexing the local
 * lookup tables with the enumerator value. `data`, `weight` and `bias` are
 * bound as inputs, in that order.
 *
 * NOTE(review): two lines are missing from this listing. One sits between
 * `no_bias` and `cudnn_off`. The other follows `cudnn_off` and closes the
 * parameter list. The body reads `cudnn_tune` and `layout`, so these are
 * presumably ConvolutionCudnnTune / ConvolutionLayout parameters. Confirm
 * against the generated op.h.
 * \return the created symbol, named `symbol_name`
 */
6121 inline Symbol Convolution(const std::string& symbol_name,
6122  Symbol data,
6123  Symbol weight,
6124  Symbol bias,
6125  Shape kernel,
6126  uint32_t num_filter,
6127  Shape stride = Shape(),
6128  Shape dilate = Shape(),
6129  Shape pad = Shape(),
6130  uint32_t num_group = 1,
6131  uint64_t workspace = 1024,
6132  bool no_bias = 0,
6134  bool cudnn_off = 0,
6136  static const char *ConvolutionCudnnTuneValues[] = {
6137  "None",
6138  "fastest",
6139  "limited_workspace",
6140  "off"
6141  };
6142  static const char *ConvolutionLayoutValues[] = {
6143  "None",
6144  "NCDHW",
6145  "NCHW",
6146  "NCW",
6147  "NDHWC",
6148  "NHWC"
6149  };
6150  return Operator("Convolution")
6151  .SetParam("kernel", kernel)
6152  .SetParam("num_filter", num_filter)
6153  .SetParam("stride", stride)
6154  .SetParam("dilate", dilate)
6155  .SetParam("pad", pad)
6156  .SetParam("num_group", num_group)
6157  .SetParam("workspace", workspace)
6158  .SetParam("no_bias", no_bias)
6159  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
6160  .SetParam("cudnn_off", cudnn_off)
6161  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
6162  .SetInput("data", data)
6163  .SetInput("weight", weight)
6164  .SetInput("bias", bias)
6165  .CreateSymbol(symbol_name);
6166 }
6167 
6248 inline Symbol BilinearSampler(const std::string& symbol_name,
6249  Symbol data,
6250  Symbol grid) {
6251  return Operator("BilinearSampler")
6252  .SetInput("data", data)
6253  .SetInput("grid", grid)
6254  .CreateSymbol(symbol_name);
6255 }
6256 
/*!
 * \brief Pooling reductions accepted by Pooling().
 *
 * Pooling() uses each enumerator's value as an index into its string table.
 * The values must therefore start at 0 and stay contiguous.
 */
enum class PoolingPoolType : int {
  kAvg,  // 0 -> "avg"
  kMax,  // 1 -> "max"
  kSum   // 2 -> "sum"
};
6264 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with PoolingPoolingConventionValues
// ("full", "valid") in Pooling, so this is presumably
// PoolingPoolingConvention. Each value indexes that table.
6268  kFull = 0,
6269  kValid = 1
6270 };
6271 
/*!
 * \brief Creates a `Pooling` operator symbol over `data`.
 *
 * `pool_type` and `pooling_convention` are converted to strings by indexing
 * the local lookup tables with the enumerator value.
 *
 * NOTE(review): one parameter line is missing from this listing, between
 * `cudnn_off` and `stride`. The body reads `pooling_convention`, so it is
 * presumably a PoolingPoolingConvention parameter. Confirm against op.h.
 * \return the created symbol, named `symbol_name`
 */
6325 inline Symbol Pooling(const std::string& symbol_name,
6326  Symbol data,
6327  Shape kernel,
6328  PoolingPoolType pool_type,
6329  bool global_pool = 0,
6330  bool cudnn_off = 0,
6332  Shape stride = Shape(),
6333  Shape pad = Shape()) {
6334  static const char *PoolingPoolTypeValues[] = {
6335  "avg",
6336  "max",
6337  "sum"
6338  };
6339  static const char *PoolingPoolingConventionValues[] = {
6340  "full",
6341  "valid"
6342  };
6343  return Operator("Pooling")
6344  .SetParam("kernel", kernel)
6345  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
6346  .SetParam("global_pool", global_pool)
6347  .SetParam("cudnn_off", cudnn_off)
6348  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
6349  .SetParam("stride", stride)
6350  .SetParam("pad", pad)
6351  .SetInput("data", data)
6352  .CreateSymbol(symbol_name);
6353 }
6354 
/*!
 * \brief Modes accepted by Dropout().
 *
 * Dropout() uses each enumerator's value as an index into its string table.
 * The values must therefore start at 0 and stay contiguous.
 */
enum class DropoutMode : int {
  kAlways,    // 0 -> "always"
  kTraining   // 1 -> "training"
};
6361 
/*!
 * \brief Creates a `Dropout` operator symbol over `data`.
 *
 * `mode` is converted to "always" or "training" by indexing the local table
 * with the enumerator value.
 *
 * NOTE(review): the parameter line after `p`, which also opens the body, is
 * missing from this listing. The body reads `mode`, so it is presumably a
 * DropoutMode parameter. Confirm against op.h.
 * \return the created symbol, named `symbol_name`
 */
6401 inline Symbol Dropout(const std::string& symbol_name,
6402  Symbol data,
6403  mx_float p = 0.5,
6405  static const char *DropoutModeValues[] = {
6406  "always",
6407  "training"
6408  };
6409  return Operator("Dropout")
6410  .SetParam("p", p)
6411  .SetParam("mode", DropoutModeValues[int(mode)])
6412  .SetInput("data", data)
6413  .CreateSymbol(symbol_name);
6414 }
6415 
/*!
 * \brief Activation functions accepted by Activation().
 *
 * Activation() uses each enumerator's value as an index into its string
 * table. The values must therefore start at 0 and stay contiguous.
 */
enum class ActivationActType : int {
  kRelu,      // 0 -> "relu"
  kSigmoid,   // 1 -> "sigmoid"
  kSoftrelu,  // 2 -> "softrelu"
  kTanh       // 3 -> "tanh"
};
6424 
6443 inline Symbol Activation(const std::string& symbol_name,
6444  Symbol data,
6445  ActivationActType act_type) {
6446  static const char *ActivationActTypeValues[] = {
6447  "relu",
6448  "sigmoid",
6449  "softrelu",
6450  "tanh"
6451  };
6452  return Operator("Activation")
6453  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
6454  .SetInput("data", data)
6455  .CreateSymbol(symbol_name);
6456 }
6457 
6514 inline Symbol ROIPooling(const std::string& symbol_name,
6515  Symbol data,
6516  Symbol rois,
6517  Shape pooled_size,
6518  mx_float spatial_scale) {
6519  return Operator("ROIPooling")
6520  .SetParam("pooled_size", pooled_size)
6521  .SetParam("spatial_scale", spatial_scale)
6522  .SetInput("data", data)
6523  .SetInput("rois", rois)
6524  .CreateSymbol(symbol_name);
6525 }
6526 
6551 inline Symbol LinearRegressionOutput(const std::string& symbol_name,
6552  Symbol data,
6553  Symbol label,
6554  mx_float grad_scale = 1) {
6555  return Operator("LinearRegressionOutput")
6556  .SetParam("grad_scale", grad_scale)
6557  .SetInput("data", data)
6558  .SetInput("label", label)
6559  .CreateSymbol(symbol_name);
6560 }
6561 
6587 inline Symbol MAERegressionOutput(const std::string& symbol_name,
6588  Symbol data,
6589  Symbol label,
6590  mx_float grad_scale = 1) {
6591  return Operator("MAERegressionOutput")
6592  .SetParam("grad_scale", grad_scale)
6593  .SetInput("data", data)
6594  .SetInput("label", label)
6595  .CreateSymbol(symbol_name);
6596 }
6597 
6623 inline Symbol LogisticRegressionOutput(const std::string& symbol_name,
6624  Symbol data,
6625  Symbol label,
6626  mx_float grad_scale = 1) {
6627  return Operator("LogisticRegressionOutput")
6628  .SetParam("grad_scale", grad_scale)
6629  .SetInput("data", data)
6630  .SetInput("label", label)
6631  .CreateSymbol(symbol_name);
6632 }
6633 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with SoftmaxActivationModeValues
// ("channel", "instance") in SoftmaxActivation, so this is presumably
// SoftmaxActivationMode. Each value indexes that table.
6638  kChannel = 0,
6639  kInstance = 1
6640 };
6641 
/*!
 * \brief Creates a `SoftmaxActivation` operator symbol over `data`.
 *
 * `mode` is converted to "channel" or "instance" by indexing the local
 * table with the enumerator value.
 *
 * NOTE(review): the parameter line after `data`, which also opens the body,
 * is missing from this listing. The body reads `mode`, so it is presumably a
 * SoftmaxActivationMode parameter. Confirm against op.h.
 * \return the created symbol, named `symbol_name`
 */
6675 inline Symbol SoftmaxActivation(const std::string& symbol_name,
6676  Symbol data,
6678  static const char *SoftmaxActivationModeValues[] = {
6679  "channel",
6680  "instance"
6681  };
6682  return Operator("SoftmaxActivation")
6683  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
6684  .SetInput("data", data)
6685  .CreateSymbol(symbol_name);
6686 }
6687 
// NOTE(review): the `enum class ... {` line of this enum is not shown in this
// listing. The enumerators line up with MakeLossNormalizationValues
// ("batch", "null", "valid") in MakeLoss, so this is presumably
// MakeLossNormalization. Each value indexes that table.
6693  kBatch = 0,
6694  kNull = 1,
6695  kValid = 2
6696 };
6697 
/*!
 * \brief Creates a `MakeLoss` operator symbol over `data`.
 *
 * `normalization` is converted to a string by indexing the local table
 * ("batch", "null", "valid") with the enumerator value.
 *
 * NOTE(review): the parameter line after `valid_thresh`, which also opens the
 * body, is missing from this listing. The body reads `normalization`, so it
 * is presumably a MakeLossNormalization parameter. Confirm against op.h.
 * \return the created symbol, named `symbol_name`
 */
6732 inline Symbol MakeLoss(const std::string& symbol_name,
6733  Symbol data,
6734  mx_float grad_scale = 1,
6735  mx_float valid_thresh = 0,
6737  static const char *MakeLossNormalizationValues[] = {
6738  "batch",
6739  "null",
6740  "valid"
6741  };
6742  return Operator("MakeLoss")
6743  .SetParam("grad_scale", grad_scale)
6744  .SetParam("valid_thresh", valid_thresh)
6745  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
6746  .SetInput("data", data)
6747  .CreateSymbol(symbol_name);
6748 }
6749 
6758 inline Symbol choose_element_0index(const std::string& symbol_name,
6759  Symbol lhs,
6760  Symbol rhs) {
6761  return Operator("choose_element_0index")
6762  .SetInput("lhs", lhs)
6763  .SetInput("rhs", rhs)
6764  .CreateSymbol(symbol_name);
6765 }
6766 
6776 inline Symbol fill_element_0index(const std::string& symbol_name,
6777  Symbol lhs,
6778  Symbol mhs,
6779  Symbol rhs) {
6780  return Operator("fill_element_0index")
6781  .SetInput("lhs", lhs)
6782  .SetInput("mhs", mhs)
6783  .SetInput("rhs", rhs)
6784  .CreateSymbol(symbol_name);
6785 }
6786 
6815 inline Symbol softmax(Symbol data,
6816  int axis = -1) {
6817  return Operator("softmax")
6818  .SetParam("axis", axis)
6819  .SetInput("data", data)
6820  .CreateSymbol();
6821 }
6822 
6845  int axis = -1) {
  // Builds a `log_softmax` symbol over `data` along `axis`. CreateSymbol()
  // is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol log_softmax(Symbol data,`.
6846  return Operator("log_softmax")
6847  .SetParam("axis", axis)
6848  .SetInput("data", data)
6849  .CreateSymbol();
6850 }
6851 
6874  Symbol rhs) {
  // Builds a `broadcast_power` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_power(Symbol lhs,`.
6875  return Operator("broadcast_power")
6876  .SetInput("lhs", lhs)
6877  .SetInput("rhs", rhs)
6878  .CreateSymbol();
6879 }
6880 
6905  Symbol rhs) {
  // Builds a `broadcast_maximum` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_maximum(Symbol lhs,`.
6906  return Operator("broadcast_maximum")
6907  .SetInput("lhs", lhs)
6908  .SetInput("rhs", rhs)
6909  .CreateSymbol();
6910 }
6911 
6936  Symbol rhs) {
  // Builds a `broadcast_minimum` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_minimum(Symbol lhs,`.
6937  return Operator("broadcast_minimum")
6938  .SetInput("lhs", lhs)
6939  .SetInput("rhs", rhs)
6940  .CreateSymbol();
6941 }
6942 
6973  Symbol rhs) {
  // Builds a `broadcast_hypot` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_hypot(Symbol lhs,`.
6974  return Operator("broadcast_hypot")
6975  .SetInput("lhs", lhs)
6976  .SetInput("rhs", rhs)
6977  .CreateSymbol();
6978 }
6979 
7053 inline Symbol Reshape(Symbol data,
7054  Shape shape = Shape(),
7055  bool reverse = 0,
7056  Shape target_shape = Shape(),
7057  bool keep_highest = 0) {
7058  return Operator("Reshape")
7059  .SetParam("shape", shape)
7060  .SetParam("reverse", reverse)
7061  .SetParam("target_shape", target_shape)
7062  .SetParam("keep_highest", keep_highest)
7063  .SetInput("data", data)
7064  .CreateSymbol();
7065 }
7066 
7096 inline Symbol Flatten(Symbol data) {
7097  return Operator("Flatten")
7098  .SetInput("data", data)
7099  .CreateSymbol();
7100 }
7101 
7138  Shape axes = Shape()) {
  // Builds a `transpose` symbol over `data` with the given `axes`.
  // CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol transpose(Symbol data,`.
7139  return Operator("transpose")
7140  .SetParam("axes", axes)
7141  .SetInput("data", data)
7142  .CreateSymbol();
7143 }
7144 
7160  int axis) {
  // Builds an `expand_dims` symbol over `data` at `axis`. CreateSymbol() is
  // called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol expand_dims(Symbol data,`.
7161  return Operator("expand_dims")
7162  .SetParam("axis", axis)
7163  .SetInput("data", data)
7164  .CreateSymbol();
7165 }
7166 
7202 inline Symbol slice(Symbol data,
7203  Shape begin,
7204  Shape end) {
7205  return Operator("slice")
7206  .SetParam("begin", begin)
7207  .SetParam("end", end)
7208  .SetInput("data", data)
7209  .CreateSymbol();
7210 }
7211 
7244  int axis,
7245  int begin,
7246  dmlc::optional<int> end) {
  // Builds a `slice_axis` symbol over `data` from the `axis`, `begin` and
  // (optional-typed) `end` params. CreateSymbol() is called without an
  // explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol slice_axis(Symbol data,`.
7247  return Operator("slice_axis")
7248  .SetParam("axis", axis)
7249  .SetParam("begin", begin)
7250  .SetParam("end", end)
7251  .SetInput("data", data)
7252  .CreateSymbol();
7253 }
7254 
7288 inline Symbol clip(Symbol data,
7289  mx_float a_min,
7290  mx_float a_max) {
7291  return Operator("clip")
7292  .SetParam("a_min", a_min)
7293  .SetParam("a_max", a_max)
7294  .SetInput("data", data)
7295  .CreateSymbol();
7296 }
7297 
7331 inline Symbol repeat(Symbol data,
7332  int repeats,
7333  dmlc::optional<int> axis = dmlc::optional<int>()) {
7334  return Operator("repeat")
7335  .SetParam("repeats", repeats)
7336  .SetParam("axis", axis)
7337  .SetInput("data", data)
7338  .CreateSymbol();
7339 }
7340 
7385 inline Symbol tile(Symbol data,
7386  Shape reps) {
7387  return Operator("tile")
7388  .SetParam("reps", reps)
7389  .SetInput("data", data)
7390  .CreateSymbol();
7391 }
7392 
7415 inline Symbol reverse(Symbol data,
7416  Shape axis) {
7417  return Operator("reverse")
7418  .SetParam("axis", axis)
7419  .SetInput("data", data)
7420  .CreateSymbol();
7421 }
7422 
7445 inline Symbol stack(const std::vector<Symbol>& data,
7446  int num_args,
7447  int axis = 0) {
7448  return Operator("stack")
7449  .SetParam("num_args", num_args)
7450  .SetParam("axis", axis)
7451 (data)
7452  .CreateSymbol();
7453 }
7454 
7477 inline Symbol zeros_like(Symbol data) {
7478  return Operator("zeros_like")
7479  .SetInput("data", data)
7480  .CreateSymbol();
7481 }
7482 
7499 inline Symbol ones_like(Symbol data) {
7500  return Operator("ones_like")
7501  .SetInput("data", data)
7502  .CreateSymbol();
7503 }
7504 
7532  Symbol rhs) {
  // Builds a `broadcast_add` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_add(Symbol lhs,`.
7533  return Operator("broadcast_add")
7534  .SetInput("lhs", lhs)
7535  .SetInput("rhs", rhs)
7536  .CreateSymbol();
7537 }
7538 
7566  Symbol rhs) {
  // Builds a `broadcast_sub` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_sub(Symbol lhs,`.
7567  return Operator("broadcast_sub")
7568  .SetInput("lhs", lhs)
7569  .SetInput("rhs", rhs)
7570  .CreateSymbol();
7571 }
7572 
7595  Symbol rhs) {
  // Builds a `broadcast_mul` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_mul(Symbol lhs,`.
7596  return Operator("broadcast_mul")
7597  .SetInput("lhs", lhs)
7598  .SetInput("rhs", rhs)
7599  .CreateSymbol();
7600 }
7601 
7624  Symbol rhs) {
  // Builds a `broadcast_div` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_div(Symbol lhs,`.
7625  return Operator("broadcast_div")
7626  .SetInput("lhs", lhs)
7627  .SetInput("rhs", rhs)
7628  .CreateSymbol();
7629 }
7630 
7653  Symbol rhs) {
  // Builds a `broadcast_mod` symbol with inputs `lhs` and `rhs`, in that
  // order. CreateSymbol() is called without an explicit name.
  // NOTE(review): the first line of this signature is not shown in this
  // listing; presumably `inline Symbol broadcast_mod(Symbol lhs,`.
7654  return Operator("broadcast_mod")
7655  .SetInput("lhs", lhs)
7656  .SetInput("rhs", rhs)
7657  .CreateSymbol();
7658 }
7659 
7679 inline Symbol add_n(const std::vector<Symbol>& args) {
7680  return Operator("add_n")
7681 (args)
7682  .CreateSymbol();
7683 }
7684 
7715 inline Symbol argmax(Symbol data,
7716  dmlc::optional<int> axis = dmlc::optional<int>(),
7717  bool keepdims = 0) {
7718  return Operator("argmax")
7719  .SetParam("axis", axis)
7720  .SetParam("keepdims", keepdims)
7721  .SetInput("data", data)
7722  .CreateSymbol();
7723 }
7724 
7755 inline Symbol argmin(Symbol data,
7756  dmlc::optional<int> axis = dmlc::optional<int>(),
7757  bool keepdims = 0) {
7758  return Operator("argmin")
7759  .SetParam("axis", axis)
7760  .SetParam("keepdims", keepdims)
7761  .SetInput("data", data)
7762  .CreateSymbol();
7763 }
7764 
7787  return Operator("argmax_channel")
7788  .SetInput("data", data)
7789  .CreateSymbol();
7790 }
7791 
7836 inline Symbol pick(Symbol data,
7837  Symbol index,
7838  dmlc::optional<int> axis = dmlc::optional<int>(),
7839  bool keepdims = 0) {
7840  return Operator("pick")
7841  .SetParam("axis", axis)
7842  .SetParam("keepdims", keepdims)
7843  .SetInput("data", data)
7844  .SetInput("index", index)
7845  .CreateSymbol();
7846 }
7847 
7886 inline Symbol dot(Symbol lhs,
7887  Symbol rhs,
7888  bool transpose_a = 0,
7889  bool transpose_b = 0) {
7890  return Operator("dot")
7891  .SetParam("transpose_a", transpose_a)
7892  .SetParam("transpose_b", transpose_b)
7893  .SetInput("lhs", lhs)
7894  .SetInput("rhs", rhs)
7895  .CreateSymbol();
7896 }
7897 
7920  Symbol rhs,
7921  bool transpose_a = 0,
7922  bool transpose_b = 0) {
7923  return Operator("batch_dot")
7924  .SetParam("transpose_a", transpose_a)
7925  .SetParam("transpose_b", transpose_b)
7926  .SetInput("lhs", lhs)
7927  .SetInput("rhs", rhs)
7928  .CreateSymbol();
7929 }
7930 
7948 inline Symbol relu(Symbol data) {
7949  return Operator("relu")
7950  .SetInput("data", data)
7951  .CreateSymbol();
7952 }
7953 
7968 inline Symbol sigmoid(Symbol data) {
7969  return Operator("sigmoid")
7970  .SetInput("data", data)
7971  .CreateSymbol();
7972 }
7973 
8006 inline Symbol BlockGrad(Symbol data) {
8007  return Operator("BlockGrad")
8008  .SetInput("data", data)
8009  .CreateSymbol();
8010 }
8011 
8040 inline Symbol make_loss(Symbol data) {
8041  return Operator("make_loss")
8042  .SetInput("data", data)
8043  .CreateSymbol();
8044 }
8045 
8053  Symbol rhs) {
8054  return Operator("reshape_like")
8055  .SetInput("lhs", lhs)
8056  .SetInput("rhs", rhs)
8057  .CreateSymbol();
8058 }
8059 
8078 inline Symbol Cast(Symbol data,
8079  CastDtype dtype) {
8080  static const char *CastDtypeValues[] = {
8081  "float16",
8082  "float32",
8083  "float64",
8084  "int32",
8085  "uint8"
8086  };
8087  return Operator("Cast")
8088  .SetParam("dtype", CastDtypeValues[int(dtype)])
8089  .SetInput("data", data)
8090  .CreateSymbol();
8091 }
8092 
8106 inline Symbol negative(Symbol data) {
8107  return Operator("negative")
8108  .SetInput("data", data)
8109  .CreateSymbol();
8110 }
8111 
8127 inline Symbol reciprocal(Symbol data) {
8128  return Operator("reciprocal")
8129  .SetInput("data", data)
8130  .CreateSymbol();
8131 }
8132 
8151 inline Symbol abs(Symbol data) {
8152  return Operator("abs")
8153  .SetInput("data", data)
8154  .CreateSymbol();
8155 }
8156 
8175 inline Symbol sign(Symbol data) {
8176  return Operator("sign")
8177  .SetInput("data", data)
8178  .CreateSymbol();
8179 }
8180 
8199 inline Symbol round(Symbol data) {
8200  return Operator("round")
8201  .SetInput("data", data)
8202  .CreateSymbol();
8203 }
8204 
8227 inline Symbol rint(Symbol data) {
8228  return Operator("rint")
8229  .SetInput("data", data)
8230  .CreateSymbol();
8231 }
8232 
8253 inline Symbol ceil(Symbol data) {
8254  return Operator("ceil")
8255  .SetInput("data", data)
8256  .CreateSymbol();
8257 }
8258 
8279 inline Symbol floor(Symbol data) {
8280  return Operator("floor")
8281  .SetInput("data", data)
8282  .CreateSymbol();
8283 }
8284 
8306 inline Symbol trunc(Symbol data) {
8307  return Operator("trunc")
8308  .SetInput("data", data)
8309  .CreateSymbol();
8310 }
8311 
8331 inline Symbol fix(Symbol data) {
8332  return Operator("fix")
8333  .SetInput("data", data)
8334  .CreateSymbol();
8335 }
8336 
8359 inline Symbol square(Symbol data) {
8360  return Operator("square")
8361  .SetInput("data", data)
8362  .CreateSymbol();
8363 }
8364 
8386 inline Symbol sqrt(Symbol data) {
8387  return Operator("sqrt")
8388  .SetInput("data", data)
8389  .CreateSymbol();
8390 }
8391 
8410 inline Symbol rsqrt(Symbol data) {
8411  return Operator("rsqrt")
8412  .SetInput("data", data)
8413  .CreateSymbol();
8414 }
8415 
8432 inline Symbol cbrt(Symbol data) {
8433  return Operator("cbrt")
8434  .SetInput("data", data)
8435  .CreateSymbol();
8436 }
8437 
8454 inline Symbol rcbrt(Symbol data) {
8455  return Operator("rcbrt")
8456  .SetInput("data", data)
8457  .CreateSymbol();
8458 }
8459 
8478 inline Symbol exp(Symbol data) {
8479  return Operator("exp")
8480  .SetInput("data", data)
8481  .CreateSymbol();
8482 }
8483 
8497 inline Symbol log(Symbol data) {
8498  return Operator("log")
8499  .SetInput("data", data)
8500  .CreateSymbol();
8501 }
8502 
8516 inline Symbol log10(Symbol data) {
8517  return Operator("log10")
8518  .SetInput("data", data)
8519  .CreateSymbol();
8520 }
8521 
8535 inline Symbol log2(Symbol data) {
8536  return Operator("log2")
8537  .SetInput("data", data)
8538  .CreateSymbol();
8539 }
8540 
8558 inline Symbol log1p(Symbol data) {
8559  return Operator("log1p")
8560  .SetInput("data", data)
8561  .CreateSymbol();
8562 }
8563 
8580 inline Symbol expm1(Symbol data) {
8581  return Operator("expm1")
8582  .SetInput("data", data)
8583  .CreateSymbol();
8584 }
8585 
8596 inline Symbol gamma(Symbol data) {
8597  return Operator("gamma")
8598  .SetInput("data", data)
8599  .CreateSymbol();
8600 }
8601 
8612 inline Symbol gammaln(Symbol data) {
8613  return Operator("gammaln")
8614  .SetInput("data", data)
8615  .CreateSymbol();
8616 }
8617 
8675 inline Symbol sum(Symbol data,
8676  Shape axis = Shape(),
8677  bool keepdims = 0,
8678  bool exclude = 0) {
8679  return Operator("sum")
8680  .SetParam("axis", axis)
8681  .SetParam("keepdims", keepdims)
8682  .SetParam("exclude", exclude)
8683  .SetInput("data", data)
8684  .CreateSymbol();
8685 }
8686 
8710 inline Symbol mean(Symbol data,
8711  Shape axis = Shape(),
8712  bool keepdims = 0,
8713  bool exclude = 0) {
8714  return Operator("mean")
8715  .SetParam("axis", axis)
8716  .SetParam("keepdims", keepdims)
8717  .SetParam("exclude", exclude)
8718  .SetInput("data", data)
8719  .CreateSymbol();
8720 }
8721 
8745 inline Symbol prod(Symbol data,
8746  Shape axis = Shape(),
8747  bool keepdims = 0,
8748  bool exclude = 0) {
8749  return Operator("prod")
8750  .SetParam("axis", axis)
8751  .SetParam("keepdims", keepdims)
8752  .SetParam("exclude", exclude)
8753  .SetInput("data", data)
8754  .CreateSymbol();
8755 }
8756 
8782 inline Symbol nansum(Symbol data,
8783  Shape axis = Shape(),
8784  bool keepdims = 0,
8785  bool exclude = 0) {
8786  return Operator("nansum")
8787  .SetParam("axis", axis)
8788  .SetParam("keepdims", keepdims)
8789  .SetParam("exclude", exclude)
8790  .SetInput("data", data)
8791  .CreateSymbol();
8792 }
8793 
8819 inline Symbol nanprod(Symbol data,
8820  Shape axis = Shape(),
8821  bool keepdims = 0,
8822  bool exclude = 0) {
8823  return Operator("nanprod")
8824  .SetParam("axis", axis)
8825  .SetParam("keepdims", keepdims)
8826  .SetParam("exclude", exclude)
8827  .SetInput("data", data)
8828  .CreateSymbol();
8829 }
8830 
8854 inline Symbol max(Symbol data,
8855  Shape axis = Shape(),
8856  bool keepdims = 0,
8857  bool exclude = 0) {
8858  return Operator("max")
8859  .SetParam("axis", axis)
8860  .SetParam("keepdims", keepdims)
8861  .SetParam("exclude", exclude)
8862  .SetInput("data", data)
8863  .CreateSymbol();
8864 }
8865 
8889 inline Symbol min(Symbol data,
8890  Shape axis = Shape(),
8891  bool keepdims = 0,
8892  bool exclude = 0) {
8893  return Operator("min")
8894  .SetParam("axis", axis)
8895  .SetParam("keepdims", keepdims)
8896  .SetParam("exclude", exclude)
8897  .SetInput("data", data)
8898  .CreateSymbol();
8899 }
8900 
8930  Shape axis = Shape(),
8931  Shape size = Shape()) {
8932  return Operator("broadcast_axis")
8933  .SetParam("axis", axis)
8934  .SetParam("size", size)
8935  .SetInput("data", data)
8936  .CreateSymbol();
8937 }
8938 
8967  Shape shape = Shape()) {
8968  return Operator("broadcast_to")
8969  .SetParam("shape", shape)
8970  .SetInput("data", data)
8971  .CreateSymbol();
8972 }
8973 
8990 inline Symbol norm(Symbol data) {
8991  return Operator("norm")
8992  .SetInput("data", data)
8993  .CreateSymbol();
8994 }
8995 
9036 inline Symbol topk(Symbol data,
9037  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9038  int k = 1,
9039  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
9040  bool is_ascend = 0) {
9041  static const char *TopkRetTypValues[] = {
9042  "both",
9043  "indices",
9044  "mask",
9045  "value"
9046  };
9047  return Operator("topk")
9048  .SetParam("axis", axis)
9049  .SetParam("k", k)
9050  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
9051  .SetParam("is_ascend", is_ascend)
9052  .SetInput("data", data)
9053  .CreateSymbol();
9054 }
9055 
9087 inline Symbol sort(Symbol data,
9088  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9089  bool is_ascend = 1) {
9090  return Operator("sort")
9091  .SetParam("axis", axis)
9092  .SetParam("is_ascend", is_ascend)
9093  .SetInput("data", data)
9094  .CreateSymbol();
9095 }
9096 
9126 inline Symbol argsort(Symbol data,
9127  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9128  bool is_ascend = 1) {
9129  return Operator("argsort")
9130  .SetParam("axis", axis)
9131  .SetParam("is_ascend", is_ascend)
9132  .SetInput("data", data)
9133  .CreateSymbol();
9134 }
9135 
9150  Symbol rhs) {
9151  return Operator("elemwise_add")
9152  .SetInput("lhs", lhs)
9153  .SetInput("rhs", rhs)
9154  .CreateSymbol();
9155 }
9156 
9171  Symbol rhs) {
9172  return Operator("elemwise_sub")
9173  .SetInput("lhs", lhs)
9174  .SetInput("rhs", rhs)
9175  .CreateSymbol();
9176 }
9177 
9195  Symbol rhs) {
9196  return Operator("elemwise_mul")
9197  .SetInput("lhs", lhs)
9198  .SetInput("rhs", rhs)
9199  .CreateSymbol();
9200 }
9201 
9213  Symbol rhs) {
9214  return Operator("elemwise_div")
9215  .SetInput("lhs", lhs)
9216  .SetInput("rhs", rhs)
9217  .CreateSymbol();
9218 }
9219 
9271  Symbol weight,
9272  int input_dim,
9273  int output_dim,
9275  static const char *EmbeddingDtypeValues[] = {
9276  "float16",
9277  "float32",
9278  "float64",
9279  "int32",
9280  "uint8"
9281  };
9282  return Operator("Embedding")
9283  .SetParam("input_dim", input_dim)
9284  .SetParam("output_dim", output_dim)
9285  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
9286  .SetInput("data", data)
9287  .SetInput("weight", weight)
9288  .CreateSymbol();
9289 }
9290 
9329 inline Symbol take(Symbol a,
9330  Symbol indices,
9331  int axis = 0,
9332  TakeMode mode = TakeMode::kClip) {
9333  static const char *TakeModeValues[] = {
9334  "clip",
9335  "raise",
9336  "wrap"
9337  };
9338  return Operator("take")
9339  .SetParam("axis", axis)
9340  .SetParam("mode", TakeModeValues[int(mode)])
9341  .SetInput("a", a)
9342  .SetInput("indices", indices)
9343  .CreateSymbol();
9344 }
9345 
9374  Symbol indices) {
9375  return Operator("batch_take")
9376  .SetInput("a", a)
9377  .SetInput("indices", indices)
9378  .CreateSymbol();
9379 }
9380 
9424 inline Symbol one_hot(Symbol indices,
9425  int depth,
9426  double on_value = 1,
9427  double off_value = 0,
9429  static const char *One_hotDtypeValues[] = {
9430  "float16",
9431  "float32",
9432  "float64",
9433  "int32",
9434  "uint8"
9435  };
9436  return Operator("one_hot")
9437  .SetParam("depth", depth)
9438  .SetParam("on_value", on_value)
9439  .SetParam("off_value", off_value)
9440  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
9441  .SetInput("indices", indices)
9442  .CreateSymbol();
9443 }
9444 
9473  Symbol indices) {
9474  return Operator("gather_nd")
9475  .SetInput("data", data)
9476  .SetInput("indices", indices)
9477  .CreateSymbol();
9478 }
9479 
9510  Symbol indices,
9511  Shape shape) {
9512  return Operator("scatter_nd")
9513  .SetParam("shape", shape)
9514  .SetInput("data", data)
9515  .SetInput("indices", indices)
9516  .CreateSymbol();
9517 }
9518 
9541  Symbol rhs) {
9542  return Operator("broadcast_equal")
9543  .SetInput("lhs", lhs)
9544  .SetInput("rhs", rhs)
9545  .CreateSymbol();
9546 }
9547 
9570  Symbol rhs) {
9571  return Operator("broadcast_not_equal")
9572  .SetInput("lhs", lhs)
9573  .SetInput("rhs", rhs)
9574  .CreateSymbol();
9575 }
9576 
9599  Symbol rhs) {
9600  return Operator("broadcast_greater")
9601  .SetInput("lhs", lhs)
9602  .SetInput("rhs", rhs)
9603  .CreateSymbol();
9604 }
9605 
9628  Symbol rhs) {
9629  return Operator("broadcast_greater_equal")
9630  .SetInput("lhs", lhs)
9631  .SetInput("rhs", rhs)
9632  .CreateSymbol();
9633 }
9634 
9657  Symbol rhs) {
9658  return Operator("broadcast_lesser")
9659  .SetInput("lhs", lhs)
9660  .SetInput("rhs", rhs)
9661  .CreateSymbol();
9662 }
9663 
9686  Symbol rhs) {
9687  return Operator("broadcast_lesser_equal")
9688  .SetInput("lhs", lhs)
9689  .SetInput("rhs", rhs)
9690  .CreateSymbol();
9691 }
9692 
9708 inline Symbol where(Symbol condition,
9709  Symbol x,
9710  Symbol y) {
9711  return Operator("where")
9712  .SetInput("condition", condition)
9713  .SetInput("x", x)
9714  .SetInput("y", y)
9715  .CreateSymbol();
9716 }
9717 
9743  mx_float scalar) {
9744  return Operator("smooth_l1")
9745  .SetParam("scalar", scalar)
9746  .SetInput("data", data)
9747  .CreateSymbol();
9748 }
9749 
9793  Cast_storageStype stype) {
9794  static const char *Cast_storageStypeValues[] = {
9795  "csr",
9796  "default",
9797  "row_sparse"
9798  };
9799  return Operator("cast_storage")
9800  .SetParam("stype", Cast_storageStypeValues[int(stype)])
9801  .SetInput("data", data)
9802  .CreateSymbol();
9803 }
9804 
9824 inline Symbol sin(Symbol data) {
9825  return Operator("sin")
9826  .SetInput("data", data)
9827  .CreateSymbol();
9828 }
9829 
9846 inline Symbol cos(Symbol data) {
9847  return Operator("cos")
9848  .SetInput("data", data)
9849  .CreateSymbol();
9850 }
9851 
9871 inline Symbol tan(Symbol data) {
9872  return Operator("tan")
9873  .SetInput("data", data)
9874  .CreateSymbol();
9875 }
9876 
9897 inline Symbol arcsin(Symbol data) {
9898  return Operator("arcsin")
9899  .SetInput("data", data)
9900  .CreateSymbol();
9901 }
9902 
9920 inline Symbol arccos(Symbol data) {
9921  return Operator("arccos")
9922  .SetInput("data", data)
9923  .CreateSymbol();
9924 }
9925 
9945 inline Symbol arctan(Symbol data) {
9946  return Operator("arctan")
9947  .SetInput("data", data)
9948  .CreateSymbol();
9949 }
9950 
9968 inline Symbol degrees(Symbol data) {
9969  return Operator("degrees")
9970  .SetInput("data", data)
9971  .CreateSymbol();
9972 }
9973 
9991 inline Symbol radians(Symbol data) {
9992  return Operator("radians")
9993  .SetInput("data", data)
9994  .CreateSymbol();
9995 }
9996 
10014 inline Symbol sinh(Symbol data) {
10015  return Operator("sinh")
10016  .SetInput("data", data)
10017  .CreateSymbol();
10018 }
10019 
10034 inline Symbol cosh(Symbol data) {
10035  return Operator("cosh")
10036  .SetInput("data", data)
10037  .CreateSymbol();
10038 }
10039 
10057 inline Symbol tanh(Symbol data) {
10058  return Operator("tanh")
10059  .SetInput("data", data)
10060  .CreateSymbol();
10061 }
10062 
10078 inline Symbol arcsinh(Symbol data) {
10079  return Operator("arcsinh")
10080  .SetInput("data", data)
10081  .CreateSymbol();
10082 }
10083 
10096 inline Symbol arccosh(Symbol data) {
10097  return Operator("arccosh")
10098  .SetInput("data", data)
10099  .CreateSymbol();
10100 }
10101 
10117 inline Symbol arctanh(Symbol data) {
10118  return Operator("arctanh")
10119  .SetInput("data", data)
10120  .CreateSymbol();
10121 }
10122 
10137 inline Symbol Custom(const std::vector<Symbol>& data,
10138  const std::string& op_type) {
10139  return Operator("Custom")
10140 (data)
10141  .CreateSymbol();
10142 }
10143 
10171 inline Symbol SwapAxis(Symbol data,
10172  uint32_t dim1 = 0,
10173  uint32_t dim2 = 0) {
10174  return Operator("SwapAxis")
10175  .SetParam("dim1", dim1)
10176  .SetParam("dim2", dim2)
10177  .SetInput("data", data)
10178  .CreateSymbol();
10179 }
10180 
10208  mx_float slope = 0.25,
10209  mx_float lower_bound = 0.125,
10210  mx_float upper_bound = 0.334) {
10211  static const char *LeakyReLUActTypeValues[] = {
10212  "elu",
10213  "leaky",
10214  "prelu",
10215  "rrelu"
10216  };
10217  return Operator("LeakyReLU")
10218  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
10219  .SetParam("slope", slope)
10220  .SetParam("lower_bound", lower_bound)
10221  .SetParam("upper_bound", upper_bound)
10222  .SetInput("data", data)
10223  .CreateSymbol();
10224 }
10225 
10281  Symbol gamma,
10282  Symbol beta,
10283  mx_float eps = 0.001,
10284  mx_float momentum = 0.9,
10285  bool fix_gamma = 1,
10286  bool use_global_stats = 0,
10287  bool output_mean_var = 0) {
10288  return Operator("BatchNorm_v1")
10289  .SetParam("eps", eps)
10290  .SetParam("momentum", momentum)
10291  .SetParam("fix_gamma", fix_gamma)
10292  .SetParam("use_global_stats", use_global_stats)
10293  .SetParam("output_mean_var", output_mean_var)
10294  .SetInput("data", data)
10295  .SetInput("gamma", gamma)
10296  .SetInput("beta", beta)
10297  .CreateSymbol();
10298 }
10299 
10340 inline Symbol Concat(const std::vector<Symbol>& data,
10341  int num_args,
10342  int dim = 1) {
10343  return Operator("Concat")
10344  .SetParam("num_args", num_args)
10345  .SetParam("dim", dim)
10346 (data)
10347  .CreateSymbol();
10348 }
10349 
10376 inline Symbol sgd_update(Symbol weight,
10377  Symbol grad,
10378  mx_float lr,
10379  mx_float wd = 0,
10380  mx_float rescale_grad = 1,
10381  mx_float clip_gradient = -1) {
10382  return Operator("sgd_update")
10383  .SetParam("lr", lr)
10384  .SetParam("wd", wd)
10385  .SetParam("rescale_grad", rescale_grad)
10386  .SetParam("clip_gradient", clip_gradient)
10387  .SetInput("weight", weight)
10388  .SetInput("grad", grad)
10389  .CreateSymbol();
10390 }
10391 
10434  Symbol grad,
10435  Symbol mom,
10436  mx_float lr,
10437  mx_float momentum = 0,
10438  mx_float wd = 0,
10439  mx_float rescale_grad = 1,
10440  mx_float clip_gradient = -1) {
10441  return Operator("sgd_mom_update")
10442  .SetParam("lr", lr)
10443  .SetParam("momentum", momentum)
10444  .SetParam("wd", wd)
10445  .SetParam("rescale_grad", rescale_grad)
10446  .SetParam("clip_gradient", clip_gradient)
10447  .SetInput("weight", weight)
10448  .SetInput("grad", grad)
10449  .SetInput("mom", mom)
10450  .CreateSymbol();
10451 }
10452 
10467  Symbol grad,
10468  Symbol weight32,
10469  mx_float lr,
10470  mx_float wd = 0,
10471  mx_float rescale_grad = 1,
10472  mx_float clip_gradient = -1) {
10473  return Operator("mp_sgd_update")
10474  .SetParam("lr", lr)
10475  .SetParam("wd", wd)
10476  .SetParam("rescale_grad", rescale_grad)
10477  .SetParam("clip_gradient", clip_gradient)
10478  .SetInput("weight", weight)
10479  .SetInput("grad", grad)
10480  .SetInput("weight32", weight32)
10481  .CreateSymbol();
10482 }
10483 
10500  Symbol grad,
10501  Symbol mom,
10502  Symbol weight32,
10503  mx_float lr,
10504  mx_float momentum = 0,
10505  mx_float wd = 0,
10506  mx_float rescale_grad = 1,
10507  mx_float clip_gradient = -1) {
10508  return Operator("mp_sgd_mom_update")
10509  .SetParam("lr", lr)
10510  .SetParam("momentum", momentum)
10511  .SetParam("wd", wd)
10512  .SetParam("rescale_grad", rescale_grad)
10513  .SetParam("clip_gradient", clip_gradient)
10514  .SetInput("weight", weight)
10515  .SetInput("grad", grad)
10516  .SetInput("mom", mom)
10517  .SetInput("weight32", weight32)
10518  .CreateSymbol();
10519 }
10520 
10567 inline Symbol adam_update(Symbol weight,
10568  Symbol grad,
10569  Symbol mean,
10570  Symbol var,
10571  mx_float lr,
10572  mx_float beta1 = 0.9,
10573  mx_float beta2 = 0.999,
10574  mx_float epsilon = 1e-08,
10575  mx_float wd = 0,
10576  mx_float rescale_grad = 1,
10577  mx_float clip_gradient = -1) {
10578  return Operator("adam_update")
10579  .SetParam("lr", lr)
10580  .SetParam("beta1", beta1)
10581  .SetParam("beta2", beta2)
10582  .SetParam("epsilon", epsilon)
10583  .SetParam("wd", wd)
10584  .SetParam("rescale_grad", rescale_grad)
10585  .SetParam("clip_gradient", clip_gradient)
10586  .SetInput("weight", weight)
10587  .SetInput("grad", grad)
10588  .SetInput("mean", mean)
10589  .SetInput("var", var)
10590  .CreateSymbol();
10591 }
10592 
10646  Symbol grad,
10647  Symbol n,
10648  mx_float lr,
10649  mx_float gamma1 = 0.95,
10650  mx_float epsilon = 1e-08,
10651  mx_float wd = 0,
10652  mx_float rescale_grad = 1,
10653  mx_float clip_gradient = -1,
10654  mx_float clip_weights = -1) {
10655  return Operator("rmsprop_update")
10656  .SetParam("lr", lr)
10657  .SetParam("gamma1", gamma1)
10658  .SetParam("epsilon", epsilon)
10659  .SetParam("wd", wd)
10660  .SetParam("rescale_grad", rescale_grad)
10661  .SetParam("clip_gradient", clip_gradient)
10662  .SetParam("clip_weights", clip_weights)
10663  .SetInput("weight", weight)
10664  .SetInput("grad", grad)
10665  .SetInput("n", n)
10666  .CreateSymbol();
10667 }
10668 
10714  Symbol grad,
10715  Symbol n,
10716  Symbol g,
10717  Symbol delta,
10718  mx_float lr,
10719  mx_float gamma1 = 0.95,
10720  mx_float gamma2 = 0.9,
10721  mx_float epsilon = 1e-08,
10722  mx_float wd = 0,
10723  mx_float rescale_grad = 1,
10724  mx_float clip_gradient = -1,
10725  mx_float clip_weights = -1) {
10726  return Operator("rmspropalex_update")
10727  .SetParam("lr", lr)
10728  .SetParam("gamma1", gamma1)
10729  .SetParam("gamma2", gamma2)
10730  .SetParam("epsilon", epsilon)
10731  .SetParam("wd", wd)
10732  .SetParam("rescale_grad", rescale_grad)
10733  .SetParam("clip_gradient", clip_gradient)
10734  .SetParam("clip_weights", clip_weights)
10735  .SetInput("weight", weight)
10736  .SetInput("grad", grad)
10737  .SetInput("n", n)
10738  .SetInput("g", g)
10739  .SetInput("delta", delta)
10740  .CreateSymbol();
10741 }
10742 
10781 inline Symbol ftrl_update(Symbol weight,
10782  Symbol grad,
10783  Symbol z,
10784  Symbol n,
10785  mx_float lr,
10786  mx_float lamda1 = 0.01,
10787  mx_float beta = 1,
10788  mx_float wd = 0,
10789  mx_float rescale_grad = 1,
10790  mx_float clip_gradient = -1) {
10791  return Operator("ftrl_update")
10792  .SetParam("lr", lr)
10793  .SetParam("lamda1", lamda1)
10794  .SetParam("beta", beta)
10795  .SetParam("wd", wd)
10796  .SetParam("rescale_grad", rescale_grad)
10797  .SetParam("clip_gradient", clip_gradient)
10798  .SetInput("weight", weight)
10799  .SetInput("grad", grad)
10800  .SetInput("z", z)
10801  .SetInput("n", n)
10802  .CreateSymbol();
10803 }
10804 
10900 inline Symbol Pad(Symbol data,
10901  PadMode mode,
10902  Shape pad_width,
10903  double constant_value = 0) {
10904  static const char *PadModeValues[] = {
10905  "constant",
10906  "edge",
10907  "reflect"
10908  };
10909  return Operator("Pad")
10910  .SetParam("mode", PadModeValues[int(mode)])
10911  .SetParam("pad_width", pad_width)
10912  .SetParam("constant_value", constant_value)
10913  .SetInput("data", data)
10914  .CreateSymbol();
10915 }
10916 
10926  mx_float sparseness_target = 0.1,
10927  mx_float penalty = 0.001,
10928  mx_float momentum = 0.9) {
10929  return Operator("IdentityAttachKLSparseReg")
10930  .SetParam("sparseness_target", sparseness_target)
10931  .SetParam("penalty", penalty)
10932  .SetParam("momentum", momentum)
10933  .SetInput("data", data)
10934  .CreateSymbol();
10935 }
10936 
11008  int num_outputs,
11009  int axis = 1,
11010  bool squeeze_axis = 0) {
11011  return Operator("SliceChannel")
11012  .SetParam("num_outputs", num_outputs)
11013  .SetParam("axis", axis)
11014  .SetParam("squeeze_axis", squeeze_axis)
11015  .SetInput("data", data)
11016  .CreateSymbol();
11017 }
11018 
11056  Symbol label) {
11057  return Operator("softmax_cross_entropy")
11058  .SetInput("data", data)
11059  .SetInput("label", label)
11060  .CreateSymbol();
11061 }
11062 
11077 inline Symbol UpSampling(const std::vector<Symbol>& data,
11078  uint32_t scale,
11079  UpSamplingSampleType sample_type,
11080  int num_args,
11081  uint32_t num_filter = 0,
11083  uint64_t workspace = 512) {
11084  static const char *UpSamplingSampleTypeValues[] = {
11085  "bilinear",
11086  "nearest"
11087  };
11088  static const char *UpSamplingMultiInputModeValues[] = {
11089  "concat",
11090  "sum"
11091  };
11092  return Operator("UpSampling")
11093  .SetParam("scale", scale)
11094  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
11095  .SetParam("num_args", num_args)
11096  .SetParam("num_filter", num_filter)
11097  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
11098  .SetParam("workspace", workspace)
11099 (data)
11100  .CreateSymbol();
11101 }
11102 
11166  Symbol gamma,
11167  Symbol beta,
11168  Symbol moving_mean,
11169  Symbol moving_var,
11170  double eps = 0.001,
11171  mx_float momentum = 0.9,
11172  bool fix_gamma = 1,
11173  bool use_global_stats = 0,
11174  bool output_mean_var = 0,
11175  int axis = 1,
11176  bool cudnn_off = 0) {
11177  return Operator("BatchNorm")
11178  .SetParam("eps", eps)
11179  .SetParam("momentum", momentum)
11180  .SetParam("fix_gamma", fix_gamma)
11181  .SetParam("use_global_stats", use_global_stats)
11182  .SetParam("output_mean_var", output_mean_var)
11183  .SetParam("axis", axis)
11184  .SetParam("cudnn_off", cudnn_off)
11185  .SetInput("data", data)
11186  .SetInput("gamma", gamma)
11187  .SetInput("beta", beta)
11188  .SetInput("moving_mean", moving_mean)
11189  .SetInput("moving_var", moving_var)
11190  .CreateSymbol();
11191 }
11192 
11243  Symbol gamma,
11244  Symbol beta,
11245  mx_float eps = 0.001) {
11246  return Operator("InstanceNorm")
11247  .SetParam("eps", eps)
11248  .SetInput("data", data)
11249  .SetInput("gamma", gamma)
11250  .SetInput("beta", beta)
11251  .CreateSymbol();
11252 }
11253 
11268 inline Symbol RNN(Symbol data,
11269  Symbol parameters,
11270  Symbol state,
11271  Symbol state_cell,
11272  uint32_t state_size,
11273  uint32_t num_layers,
11274  RNNMode mode,
11275  bool bidirectional = 0,
11276  mx_float p = 0,
11277  bool state_outputs = 0) {
11278  static const char *RNNModeValues[] = {
11279  "gru",
11280  "lstm",
11281  "rnn_relu",
11282  "rnn_tanh"
11283  };
11284  return Operator("RNN")
11285  .SetParam("state_size", state_size)
11286  .SetParam("num_layers", num_layers)
11287  .SetParam("mode", RNNModeValues[int(mode)])
11288  .SetParam("bidirectional", bidirectional)
11289  .SetParam("p", p)
11290  .SetParam("state_outputs", state_outputs)
11291  .SetInput("data", data)
11292  .SetInput("parameters", parameters)
11293  .SetInput("state", state)
11294  .SetInput("state_cell", state_cell)
11295  .CreateSymbol();
11296 }
11297 
11326  Symbol weight,
11327  Symbol bias,
11328  Shape kernel,
11329  uint32_t num_filter,
11330  Shape stride = Shape(),
11331  Shape dilate = Shape(),
11332  Shape pad = Shape(),
11333  uint32_t num_group = 1,
11334  uint64_t workspace = 1024,
11335  bool no_bias = 0,
11337  bool cudnn_off = 0,
11339  static const char *Convolution_v1CudnnTuneValues[] = {
11340  "None",
11341  "fastest",
11342  "limited_workspace",
11343  "off"
11344  };
11345  static const char *Convolution_v1LayoutValues[] = {
11346  "None",
11347  "NCDHW",
11348  "NCHW",
11349  "NDHWC",
11350  "NHWC"
11351  };
11352  return Operator("Convolution_v1")
11353  .SetParam("kernel", kernel)
11354  .SetParam("num_filter", num_filter)
11355  .SetParam("stride", stride)
11356  .SetParam("dilate", dilate)
11357  .SetParam("pad", pad)
11358  .SetParam("num_group", num_group)
11359  .SetParam("workspace", workspace)
11360  .SetParam("no_bias", no_bias)
11361  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
11362  .SetParam("cudnn_off", cudnn_off)
11363  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
11364  .SetInput("data", data)
11365  .SetInput("weight", weight)
11366  .SetInput("bias", bias)
11367  .CreateSymbol();
11368 }
11369 
11389 inline Symbol Crop(const std::vector<Symbol>& data,
11390  int num_args,
11391  Shape offset = Shape(0,0),
11392  Shape h_w = Shape(0,0),
11393  bool center_crop = 0) {
11394  return Operator("Crop")
11395  .SetParam("num_args", num_args)
11396  .SetParam("offset", offset)
11397  .SetParam("h_w", h_w)
11398  .SetParam("center_crop", center_crop)
11399 (data)
11400  .CreateSymbol();
11401 }
11402 
11413  Symbol loc,
11414  SpatialTransformerTransformType transform_type,
11415  SpatialTransformerSamplerType sampler_type,
11416  Shape target_shape = Shape(0,0)) {
11417  static const char *SpatialTransformerTransformTypeValues[] = {
11418  "affine"
11419  };
11420  static const char *SpatialTransformerSamplerTypeValues[] = {
11421  "bilinear"
11422  };
11423  return Operator("SpatialTransformer")
11424  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
11425  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
11426  .SetParam("target_shape", target_shape)
11427  .SetInput("data", data)
11428  .SetInput("loc", loc)
11429  .CreateSymbol();
11430 }
11431 
11458  Symbol weight,
11459  Symbol bias,
11460  Shape kernel,
11461  uint32_t num_filter,
11462  Shape stride = Shape(),
11463  Shape dilate = Shape(),
11464  Shape pad = Shape(),
11465  Shape adj = Shape(),
11466  Shape target_shape = Shape(),
11467  uint32_t num_group = 1,
11468  uint64_t workspace = 512,
11469  bool no_bias = 1,
11471  bool cudnn_off = 0,
11473  static const char *DeconvolutionCudnnTuneValues[] = {
11474  "None",
11475  "fastest",
11476  "limited_workspace",
11477  "off"
11478  };
11479  static const char *DeconvolutionLayoutValues[] = {
11480  "None",
11481  "NCDHW",
11482  "NCHW",
11483  "NCW",
11484  "NDHWC",
11485  "NHWC"
11486  };
11487  return Operator("Deconvolution")
11488  .SetParam("kernel", kernel)
11489  .SetParam("num_filter", num_filter)
11490  .SetParam("stride", stride)
11491  .SetParam("dilate", dilate)
11492  .SetParam("pad", pad)
11493  .SetParam("adj", adj)
11494  .SetParam("target_shape", target_shape)
11495  .SetParam("num_group", num_group)
11496  .SetParam("workspace", workspace)
11497  .SetParam("no_bias", no_bias)
11498  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
11499  .SetParam("cudnn_off", cudnn_off)
11500  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
11501  .SetInput("data", data)
11502  .SetInput("weight", weight)
11503  .SetInput("bias", bias)
11504  .CreateSymbol();
11505 }
11506 
11602  Symbol label,
11603  mx_float grad_scale = 1,
11604  mx_float ignore_label = -1,
11605  bool multi_output = 0,
11606  bool use_ignore = 0,
11607  bool preserve_shape = 0,
11609  bool out_grad = 0,
11610  mx_float smooth_alpha = 0) {
11611  static const char *SoftmaxOutputNormalizationValues[] = {
11612  "batch",
11613  "null",
11614  "valid"
11615  };
11616  return Operator("SoftmaxOutput")
11617  .SetParam("grad_scale", grad_scale)
11618  .SetParam("ignore_label", ignore_label)
11619  .SetParam("multi_output", multi_output)
11620  .SetParam("use_ignore", use_ignore)
11621  .SetParam("preserve_shape", preserve_shape)
11622  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
11623  .SetParam("out_grad", out_grad)
11624  .SetParam("smooth_alpha", smooth_alpha)
11625  .SetInput("data", data)
11626  .SetInput("label", label)
11627  .CreateSymbol();
11628 }
11629 
11656 inline Symbol Softmax(Symbol data,
11657  mx_float grad_scale = 1,
11658  mx_float ignore_label = -1,
11659  bool multi_output = 0,
11660  bool use_ignore = 0,
11661  bool preserve_shape = 0,
11663  bool out_grad = 0,
11664  mx_float smooth_alpha = 0) {
11665  static const char *SoftmaxNormalizationValues[] = {
11666  "batch",
11667  "null",
11668  "valid"
11669  };
11670  return Operator("Softmax")
11671  .SetParam("grad_scale", grad_scale)
11672  .SetParam("ignore_label", ignore_label)
11673  .SetParam("multi_output", multi_output)
11674  .SetParam("use_ignore", use_ignore)
11675  .SetParam("preserve_shape", preserve_shape)
11676  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
11677  .SetParam("out_grad", out_grad)
11678  .SetParam("smooth_alpha", smooth_alpha)
11679  .SetInput("data", data)
11680  .CreateSymbol();
11681 }
11682 
11758  Symbol sequence_length,
11759  bool use_sequence_length = 0) {
11760  return Operator("SequenceReverse")
11761  .SetParam("use_sequence_length", use_sequence_length)
11762  .SetInput("data", data)
11763  .SetInput("sequence_length", sequence_length)
11764  .CreateSymbol();
11765 }
11766 
11821  Symbol sequence_length,
11822  bool use_sequence_length = 0) {
11823  return Operator("SequenceLast")
11824  .SetParam("use_sequence_length", use_sequence_length)
11825  .SetInput("data", data)
11826  .SetInput("sequence_length", sequence_length)
11827  .CreateSymbol();
11828 }
11829 
11878  Symbol data2,
11879  uint32_t kernel_size = 1,
11880  uint32_t max_displacement = 1,
11881  uint32_t stride1 = 1,
11882  uint32_t stride2 = 1,
11883  uint32_t pad_size = 0,
11884  bool is_multiply = 1) {
11885  return Operator("Correlation")
11886  .SetParam("kernel_size", kernel_size)
11887  .SetParam("max_displacement", max_displacement)
11888  .SetParam("stride1", stride1)
11889  .SetParam("stride2", stride2)
11890  .SetParam("pad_size", pad_size)
11891  .SetParam("is_multiply", is_multiply)
11892  .SetInput("data1", data1)
11893  .SetInput("data2", data2)
11894  .CreateSymbol();
11895 }
11896 
11912  Symbol label,
11913  mx_float margin = 1,
11914  mx_float regularization_coefficient = 1,
11915  bool use_linear = 0) {
11916  return Operator("SVMOutput")
11917  .SetParam("margin", margin)
11918  .SetParam("regularization_coefficient", regularization_coefficient)
11919  .SetParam("use_linear", use_linear)
11920  .SetInput("data", data)
11921  .SetInput("label", label)
11922  .CreateSymbol();
11923 }
11924 
11987  mx_float eps = 1e-10,
11989  static const char *L2NormalizationModeValues[] = {
11990  "channel",
11991  "instance",
11992  "spatial"
11993  };
11994  return Operator("L2Normalization")
11995  .SetParam("eps", eps)
11996  .SetParam("mode", L2NormalizationModeValues[int(mode)])
11997  .SetInput("data", data)
11998  .CreateSymbol();
11999 }
12000 
12027 inline Symbol LRN(Symbol data,
12028  uint32_t nsize,
12029  mx_float alpha = 0.0001,
12030  mx_float beta = 0.75,
12031  mx_float knorm = 2) {
12032  return Operator("LRN")
12033  .SetParam("nsize", nsize)
12034  .SetParam("alpha", alpha)
12035  .SetParam("beta", beta)
12036  .SetParam("knorm", knorm)
12037  .SetInput("data", data)
12038  .CreateSymbol();
12039 }
12040 
12074  Symbol weight,
12075  Symbol bias,
12076  int num_hidden,
12077  bool no_bias = 0,
12078  bool flatten = 1) {
12079  return Operator("FullyConnected")
12080  .SetParam("num_hidden", num_hidden)
12081  .SetParam("no_bias", no_bias)
12082  .SetParam("flatten", flatten)
12083  .SetInput("data", data)
12084  .SetInput("weight", weight)
12085  .SetInput("bias", bias)
12086  .CreateSymbol();
12087 }
12088 
12166  Symbol sequence_length,
12167  bool use_sequence_length = 0,
12168  mx_float value = 0) {
12169  return Operator("SequenceMask")
12170  .SetParam("use_sequence_length", use_sequence_length)
12171  .SetParam("value", value)
12172  .SetInput("data", data)
12173  .SetInput("sequence_length", sequence_length)
12174  .CreateSymbol();
12175 }
12176 
12187  GridGeneratorTransformType transform_type,
12188  Shape target_shape = Shape(0,0)) {
12189  static const char *GridGeneratorTransformTypeValues[] = {
12190  "affine",
12191  "warp"
12192  };
12193  return Operator("GridGenerator")
12194  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
12195  .SetParam("target_shape", target_shape)
12196  .SetInput("data", data)
12197  .CreateSymbol();
12198 }
12199 
12251  Shape kernel,
12252  Pooling_v1PoolType pool_type,
12253  bool global_pool = 0,
12255  Shape stride = Shape(),
12256  Shape pad = Shape()) {
12257  static const char *Pooling_v1PoolTypeValues[] = {
12258  "avg",
12259  "max",
12260  "sum"
12261  };
12262  static const char *Pooling_v1PoolingConventionValues[] = {
12263  "full",
12264  "valid"
12265  };
12266  return Operator("Pooling_v1")
12267  .SetParam("kernel", kernel)
12268  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
12269  .SetParam("global_pool", global_pool)
12270  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
12271  .SetParam("stride", stride)
12272  .SetParam("pad", pad)
12273  .SetInput("data", data)
12274  .CreateSymbol();
12275 }
12276 
12371  Symbol weight,
12372  Symbol bias,
12373  Shape kernel,
12374  uint32_t num_filter,
12375  Shape stride = Shape(),
12376  Shape dilate = Shape(),
12377  Shape pad = Shape(),
12378  uint32_t num_group = 1,
12379  uint64_t workspace = 1024,
12380  bool no_bias = 0,
12382  bool cudnn_off = 0,
12384  static const char *ConvolutionCudnnTuneValues[] = {
12385  "None",
12386  "fastest",
12387  "limited_workspace",
12388  "off"
12389  };
12390  static const char *ConvolutionLayoutValues[] = {
12391  "None",
12392  "NCDHW",
12393  "NCHW",
12394  "NCW",
12395  "NDHWC",
12396  "NHWC"
12397  };
12398  return Operator("Convolution")
12399  .SetParam("kernel", kernel)
12400  .SetParam("num_filter", num_filter)
12401  .SetParam("stride", stride)
12402  .SetParam("dilate", dilate)
12403  .SetParam("pad", pad)
12404  .SetParam("num_group", num_group)
12405  .SetParam("workspace", workspace)
12406  .SetParam("no_bias", no_bias)
12407  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
12408  .SetParam("cudnn_off", cudnn_off)
12409  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
12410  .SetInput("data", data)
12411  .SetInput("weight", weight)
12412  .SetInput("bias", bias)
12413  .CreateSymbol();
12414 }
12415 
12496  Symbol grid) {
12497  return Operator("BilinearSampler")
12498  .SetInput("data", data)
12499  .SetInput("grid", grid)
12500  .CreateSymbol();
12501 }
12502 
12555 inline Symbol Pooling(Symbol data,
12556  Shape kernel,
12557  PoolingPoolType pool_type,
12558  bool global_pool = 0,
12559  bool cudnn_off = 0,
12561  Shape stride = Shape(),
12562  Shape pad = Shape()) {
12563  static const char *PoolingPoolTypeValues[] = {
12564  "avg",
12565  "max",
12566  "sum"
12567  };
12568  static const char *PoolingPoolingConventionValues[] = {
12569  "full",
12570  "valid"
12571  };
12572  return Operator("Pooling")
12573  .SetParam("kernel", kernel)
12574  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
12575  .SetParam("global_pool", global_pool)
12576  .SetParam("cudnn_off", cudnn_off)
12577  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
12578  .SetParam("stride", stride)
12579  .SetParam("pad", pad)
12580  .SetInput("data", data)
12581  .CreateSymbol();
12582 }
12583 
12622 inline Symbol Dropout(Symbol data,
12623  mx_float p = 0.5,
12625  static const char *DropoutModeValues[] = {
12626  "always",
12627  "training"
12628  };
12629  return Operator("Dropout")
12630  .SetParam("p", p)
12631  .SetParam("mode", DropoutModeValues[int(mode)])
12632  .SetInput("data", data)
12633  .CreateSymbol();
12634 }
12635 
12654  ActivationActType act_type) {
12655  static const char *ActivationActTypeValues[] = {
12656  "relu",
12657  "sigmoid",
12658  "softrelu",
12659  "tanh"
12660  };
12661  return Operator("Activation")
12662  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
12663  .SetInput("data", data)
12664  .CreateSymbol();
12665 }
12666 
12723  Symbol rois,
12724  Shape pooled_size,
12725  mx_float spatial_scale) {
12726  return Operator("ROIPooling")
12727  .SetParam("pooled_size", pooled_size)
12728  .SetParam("spatial_scale", spatial_scale)
12729  .SetInput("data", data)
12730  .SetInput("rois", rois)
12731  .CreateSymbol();
12732 }
12733 
12758  Symbol label,
12759  mx_float grad_scale = 1) {
12760  return Operator("LinearRegressionOutput")
12761  .SetParam("grad_scale", grad_scale)
12762  .SetInput("data", data)
12763  .SetInput("label", label)
12764  .CreateSymbol();
12765 }
12766 
12792  Symbol label,
12793  mx_float grad_scale = 1) {
12794  return Operator("MAERegressionOutput")
12795  .SetParam("grad_scale", grad_scale)
12796  .SetInput("data", data)
12797  .SetInput("label", label)
12798  .CreateSymbol();
12799 }
12800 
12826  Symbol label,
12827  mx_float grad_scale = 1) {
12828  return Operator("LogisticRegressionOutput")
12829  .SetParam("grad_scale", grad_scale)
12830  .SetInput("data", data)
12831  .SetInput("label", label)
12832  .CreateSymbol();
12833 }
12834 
12869  static const char *SoftmaxActivationModeValues[] = {
12870  "channel",
12871  "instance"
12872  };
12873  return Operator("SoftmaxActivation")
12874  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
12875  .SetInput("data", data)
12876  .CreateSymbol();
12877 }
12878 
12912 inline Symbol MakeLoss(Symbol data,
12913  mx_float grad_scale = 1,
12914  mx_float valid_thresh = 0,
12916  static const char *MakeLossNormalizationValues[] = {
12917  "batch",
12918  "null",
12919  "valid"
12920  };
12921  return Operator("MakeLoss")
12922  .SetParam("grad_scale", grad_scale)
12923  .SetParam("valid_thresh", valid_thresh)
12924  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
12925  .SetInput("data", data)
12926  .CreateSymbol();
12927 }
12928 
12937  Symbol rhs) {
12938  return Operator("choose_element_0index")
12939  .SetInput("lhs", lhs)
12940  .SetInput("rhs", rhs)
12941  .CreateSymbol();
12942 }
12943 
12953  Symbol mhs,
12954  Symbol rhs) {
12955  return Operator("fill_element_0index")
12956  .SetInput("lhs", lhs)
12957  .SetInput("mhs", mhs)
12958  .SetInput("rhs", rhs)
12959  .CreateSymbol();
12960 }
12961 
12962 } //namespace cpp
12963 } //namespace mxnet
12964 #endif // MXNET_CPP_OP_H_
Symbol fix(const std::string &symbol_name, Symbol data)
Definition: op.h:1669
Symbol broadcast_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:872
Symbol arcsin(const std::string &symbol_name, Symbol data)
Definition: op.h:3383
Symbol arccosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3600
Symbol arctan(const std::string &symbol_name, Symbol data)
Definition: op.h:3435
Symbol SwapAxis(const std::string &symbol_name, Symbol data, uint32_t dim1=0, uint32_t dim2=0)
Definition: op.h:3681
Symbol cast_storage(const std::string &symbol_name, Symbol data, Cast_storageStype stype)
Definition: op.h:3270
Symbol add_n(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:963
Symbol log1p(const std::string &symbol_name, Symbol data)
Definition: op.h:1916
Symbol Convolution_v1(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=0, Convolution_v1CudnnTune cudnn_tune=Convolution_v1CudnnTune::kNone, bool cudnn_off=0, Convolution_v1Layout layout=Convolution_v1Layout::kNone)
Definition: op.h:4944
Symbol BatchNorm_v1(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001, mx_float momentum=0.9, bool fix_gamma=1, bool use_global_stats=0, bool output_mean_var=0)
Definition: op.h:3803
SoftmaxActivationMode
Definition: op.h:6637
Symbol mp_sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol weight32, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:3997
Symbol SpatialTransformer(const std::string &symbol_name, Symbol data, Symbol loc, SpatialTransformerTransformType transform_type, SpatialTransformerSamplerType sampler_type, Shape target_shape=Shape(0, 0))
Definition: op.h:5047
Symbol RNN(const std::string &symbol_name, Symbol data, Symbol parameters, Symbol state, Symbol state_cell, uint32_t state_size, uint32_t num_layers, RNNMode mode, bool bidirectional=0, mx_float p=0, bool state_outputs=0)
Definition: op.h:4858
Symbol exp(const std::string &symbol_name, Symbol data)
Definition: op.h:1828
Symbol transpose(const std::string &symbol_name, Symbol data, Shape axes=Shape())
Definition: op.h:389
Symbol Correlation(const std::string &symbol_name, Symbol data1, Symbol data2, uint32_t kernel_size=1, uint32_t max_displacement=1, uint32_t stride1=1, uint32_t stride2=1, uint32_t pad_size=0, bool is_multiply=1)
Definition: op.h:5560
Symbol clip(const std::string &symbol_name, Symbol data, mx_float a_min, mx_float a_max)
Definition: op.h:548
Symbol elemwise_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2622
Symbol ROIPooling(const std::string &symbol_name, Symbol data, Symbol rois, Shape pooled_size, mx_float spatial_scale)
Definition: op.h:6514
Symbol broadcast_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:903
Symbol BatchNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, Symbol moving_mean, Symbol moving_var, double eps=0.001, mx_float momentum=0.9, bool fix_gamma=1, bool use_global_stats=0, bool output_mean_var=0, int axis=1, bool cudnn_off=0)
Definition: op.h:4742
Convolution_v1Layout
Definition: op.h:4908
Symbol mp_sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, Symbol weight32, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4032
Symbol broadcast_lesser(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3118
Symbol fill_element_0index(const std::string &symbol_name, Symbol lhs, Symbol mhs, Symbol rhs)
Definition: op.h:6776
Symbol SliceChannel(const std::string &symbol_name, Symbol data, int num_outputs, int axis=1, bool squeeze_axis=0)
Definition: op.h:4563
Symbol broadcast_not_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3025
TakeMode
Definition: op.h:2718
Symbol Embedding(const std::string &symbol_name, Symbol data, Symbol weight, int input_dim, int output_dim, EmbeddingDtype dtype=EmbeddingDtype::kFloat32)
Definition: op.h:2692
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=1)
Definition: op.h:2487
Symbol ftrl_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol z, Symbol n, mx_float lr, mx_float lamda1=0.01, mx_float beta=1, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4322
Symbol reciprocal(const std::string &symbol_name, Symbol data)
Definition: op.h:1449
TopkRetTyp
Definition: op.h:2386
namespace of mxnet
Definition: base.h:126
Symbol reshape_like(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1358
Pooling_v1PoolingConvention
Definition: op.h:5922
Symbol broadcast_lesser_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3149
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
Symbol InstanceNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001)
Definition: op.h:4821
Symbol sign(const std::string &symbol_name, Symbol data)
Definition: op.h:1501
Symbol argmax(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=0)
Definition: op.h:1001
GridGeneratorTransformType
Definition: op.h:5882
Symbol argsort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=1)
Definition: op.h:2528
Cast_storageStype
Definition: op.h:3221
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:42
Symbol ones_like(const std::string &symbol_name, Symbol data)
Definition: op.h:771
RNNMode
Definition: op.h:4836
Symbol SequenceMask(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=0, mx_float value=0)
Definition: op.h:5866
PadMode
Definition: op.h:4350
Symbol smooth_l1(const std::string &symbol_name, Symbol data, mx_float scalar)
Definition: op.h:3210
Symbol where(const std::string &symbol_name, Symbol condition, Symbol x, Symbol y)
Definition: op.h:3174
Symbol Dropout(const std::string &symbol_name, Symbol data, mx_float p=0.5, DropoutMode mode=DropoutMode::kTraining)
Definition: op.h:6401
Symbol expm1(const std::string &symbol_name, Symbol data)
Definition: op.h:1940
Symbol elemwise_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2553
PoolingPoolType
Definition: op.h:6259
Symbol relu(const std::string &symbol_name, Symbol data)
Definition: op.h:1246
Symbol reverse(const std::string &symbol_name, Symbol data, Shape axis)
Definition: op.h:681
Symbol rsqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1754
SpatialTransformerTransformType
Definition: op.h:5027
ActivationActType
Definition: op.h:6418
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1728
Symbol rint(const std::string &symbol_name, Symbol data)
Definition: op.h:1557
Symbol IdentityAttachKLSparseReg(const std::string &symbol_name, Symbol data, mx_float sparseness_target=0.1, mx_float penalty=0.001, mx_float momentum=0.9)
Definition: op.h:4479
Symbol sinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3510
Symbol scatter_nd(const std::string &symbol_name, Symbol data, Symbol indices, Shape shape)
Definition: op.h:2961
Symbol broadcast_greater_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3087
Symbol LRN(const std::string &symbol_name, Symbol data, uint32_t nsize, mx_float alpha=0.0001, mx_float beta=0.75, mx_float knorm=2)
Definition: op.h:5724
Symbol sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:3962
Symbol arcsinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3580
Symbol max(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=0, bool exclude=0)
Definition: op.h:2230
Symbol MAERegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6587
Symbol sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:3903
PoolingPoolingConvention
Definition: op.h:6267
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=0, Shape target_shape=Shape(), bool keep_highest=0)
Definition: op.h:301
Symbol broadcast_minimum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:179
Symbol Deconvolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), Shape adj=Shape(), Shape target_shape=Shape(), uint32_t num_group=1, uint64_t workspace=512, bool no_bias=1, DeconvolutionCudnnTune cudnn_tune=DeconvolutionCudnnTune::kNone, bool cudnn_off=0, DeconvolutionLayout layout=DeconvolutionLayout::kNone)
Definition: op.h:5114
Symbol broadcast_maximum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:146
Symbol Cast(const std::string &symbol_name, Symbol data, CastDtype dtype)
Definition: op.h:1396
DeconvolutionLayout
Definition: op.h:5079
Symbol trunc(const std::string &symbol_name, Symbol data)
Definition: op.h:1642
Pooling_v1PoolType
Definition: op.h:5914
Symbol round(const std::string &symbol_name, Symbol data)
Definition: op.h:1527
Symbol log_softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:82
Symbol cos(const std::string &symbol_name, Symbol data)
Definition: op.h:3328
Symbol L2Normalization(const std::string &symbol_name, Symbol data, mx_float eps=1e-10, L2NormalizationMode mode=L2NormalizationMode::kInstance)
Definition: op.h:5681
Symbol zeros_like(const std::string &symbol_name, Symbol data)
Definition: op.h:747
EmbeddingDtype
Definition: op.h:2633
Symbol broadcast_mod(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:934
Symbol Crop(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, Shape offset=Shape(0, 0), Shape h_w=Shape(0, 0), bool center_crop=0)
Definition: op.h:5010
Symbol cbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1778
Symbol sum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=0, bool exclude=0)
Definition: op.h:2041
operator helper functions
Symbol tanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3557
Symbol broadcast_to(const std::string &symbol_name, Symbol data, Shape shape=Shape())
Definition: op.h:2348
Symbol elemwise_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2576
DropoutMode
Definition: op.h:6357
Symbol topk(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), int k=1, TopkRetTyp ret_typ=TopkRetTyp::kIndices, bool is_ascend=0)
Definition: op.h:2434
Symbol MakeLoss(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float valid_thresh=0, MakeLossNormalization normalization=MakeLossNormalization::kNull)
Definition: op.h:6732
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:1849
Symbol Pooling_v1(const std::string &symbol_name, Symbol data, Shape kernel, Pooling_v1PoolType pool_type, bool global_pool=0, Pooling_v1PoolingConvention pooling_convention=Pooling_v1PoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:5978
Symbol SequenceLast(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=0)
Definition: op.h:5501
Symbol sigmoid(const std::string &symbol_name, Symbol data)
Definition: op.h:1268
CastDtype
Definition: op.h:1369
ConvolutionLayout
Definition: op.h:6018
Symbol LogisticRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6623
Symbol gamma(const std::string &symbol_name, Symbol data)
Definition: op.h:1958
Symbol sin(const std::string &symbol_name, Symbol data)
Definition: op.h:3304
Symbol SVMOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float margin=1, mx_float regularization_coefficient=1, bool use_linear=0)
Definition: op.h:5596
UpSamplingMultiInputMode
Definition: op.h:4632
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
Symbol elemwise_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2602
SpatialTransformerSamplerType
Definition: op.h:5033
Symbol nansum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=0, bool exclude=0)
Definition: op.h:2154
Symbol Pad(const std::string &symbol_name, Symbol data, PadMode mode, Shape pad_width, double constant_value=0)
Definition: op.h:4452
Symbol square(const std::string &symbol_name, Symbol data)
Definition: op.h:1699
One_hotDtype
Definition: op.h:2820
UpSamplingSampleType
Definition: op.h:4624
Symbol norm(const std::string &symbol_name, Symbol data)
Definition: op.h:2374
Symbol rmsprop_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, mx_float lr, mx_float gamma1=0.95, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:4182
Symbol LeakyReLU(const std::string &symbol_name, Symbol data, LeakyReLUActType act_type=LeakyReLUActType::kLeaky, mx_float slope=0.25, mx_float lower_bound=0.125, mx_float upper_bound=0.334)
Definition: op.h:3727
Symbol make_loss(const std::string &symbol_name, Symbol data)
Definition: op.h:1344
Symbol argmin(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=0)
Definition: op.h:1043
Symbol SoftmaxActivation(const std::string &symbol_name, Symbol data, SoftmaxActivationMode mode=SoftmaxActivationMode::kInstance)
Definition: op.h:6675
Symbol mean(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=0, bool exclude=0)
Definition: op.h:2078
Symbol slice(const std::string &symbol_name, Symbol data, Shape begin, Shape end)
Definition: op.h:458
Symbol broadcast_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2994
Symbol batch_dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=0, bool transpose_b=0)
Definition: op.h:1215
Symbol broadcast_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:805
Symbol adam_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mean, Symbol var, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4102
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:57
Symbol tan(const std::string &symbol_name, Symbol data)
Definition: op.h:3355
Convolution_v1CudnnTune
Definition: op.h:4898
Symbol repeat(const std::string &symbol_name, Symbol data, int repeats, dmlc::optional< int > axis=dmlc::optional< int >())
Definition: op.h:593
Symbol slice_axis(const std::string &symbol_name, Symbol data, int axis, int begin, dmlc::optional< int > end)
Definition: op.h:501
Symbol expand_dims(const std::string &symbol_name, Symbol data, int axis)
Definition: op.h:413
Symbol arctanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3623
Symbol softmax_cross_entropy(const std::string &symbol_name, Symbol data, Symbol label)
Definition: op.h:4613
Symbol Convolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=0, ConvolutionCudnnTune cudnn_tune=ConvolutionCudnnTune::kNone, bool cudnn_off=0, ConvolutionLayout layout=ConvolutionLayout::kNone)
Definition: op.h:6121
Symbol broadcast_axis(const std::string &symbol_name, Symbol data, Shape axis=Shape(), Shape size=Shape())
Definition: op.h:2309
Symbol dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=0, bool transpose_b=0)
Definition: op.h:1180
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1475
Symbol cosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3532
Symbol gather_nd(const std::string &symbol_name, Symbol data, Symbol indices)
Definition: op.h:2922
Symbol FullyConnected(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, int num_hidden, bool no_bias=0, bool flatten=1)
Definition: op.h:5772
Symbol BilinearSampler(const std::string &symbol_name, Symbol data, Symbol grid)
Definition: op.h:6248
Symbol Custom(const std::string &symbol_name, const std::vector< Symbol > &data, const std::string &op_type)
Definition: op.h:3645
Symbol SequenceReverse(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=0)
Definition: op.h:5436
Symbol broadcast_hypot(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:218
Symbol SoftmaxOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=0, bool use_ignore=0, bool preserve_shape=0, SoftmaxOutputNormalization normalization=SoftmaxOutputNormalization::kNull, bool out_grad=0, mx_float smooth_alpha=0)
Definition: op.h:5268
Symbol UpSampling(const std::string &symbol_name, const std::vector< Symbol > &data, uint32_t scale, UpSamplingSampleType sample_type, int num_args, uint32_t num_filter=0, UpSamplingMultiInputMode multi_input_mode=UpSamplingMultiInputMode::kConcat, uint64_t workspace=512)
Definition: op.h:4652
Symbol Activation(const std::string &symbol_name, Symbol data, ActivationActType act_type)
Definition: op.h:6443
float mx_float
manually define float
Definition: c_api.h:59
Symbol radians(const std::string &symbol_name, Symbol data)
Definition: op.h:3485
Symbol Concat(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int dim=1)
Definition: op.h:3865
L2NormalizationMode
Definition: op.h:5613
Symbol stack(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int axis=0)
Definition: op.h:713
Symbol min(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=0, bool exclude=0)
Definition: op.h:2267
Symbol floor(const std::string &symbol_name, Symbol data)
Definition: op.h:1613
Symbol broadcast_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:841
Symbol take(const std::string &symbol_name, Symbol a, Symbol indices, int axis=0, TakeMode mode=TakeMode::kClip)
Definition: op.h:2763
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:1585
Symbol gammaln(const std::string &symbol_name, Symbol data)
Definition: op.h:1976
Symbol tile(const std::string &symbol_name, Symbol data, Shape reps)
Definition: op.h:649
Symbol Softmax(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=0, bool use_ignore=0, bool preserve_shape=0, SoftmaxNormalization normalization=SoftmaxNormalization::kNull, bool out_grad=0, mx_float smooth_alpha=0)
Definition: op.h:5333
Symbol rmspropalex_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, Symbol g, Symbol delta, mx_float lr, mx_float gamma1=0.95, mx_float gamma2=0.9, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:4252
SoftmaxNormalization
Definition: op.h:5300
Symbol softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:51
DeconvolutionCudnnTune
Definition: op.h:5070
ConvolutionCudnnTune
Definition: op.h:6008
definition of shape
Symbol broadcast_greater(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3056
Symbol rcbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1802
Symbol prod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=0, bool exclude=0)
Definition: op.h:2115
Symbol Pooling(const std::string &symbol_name, Symbol data, Shape kernel, PoolingPoolType pool_type, bool global_pool=0, bool cudnn_off=0, PoolingPoolingConvention pooling_convention=PoolingPoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:6325
Symbol broadcast_power(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:113
Symbol nanprod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=0, bool exclude=0)
Definition: op.h:2193
SoftmaxOutputNormalization
Definition: op.h:5167
Symbol Flatten(const std::string &symbol_name, Symbol data)
Definition: op.h:346
Symbol BlockGrad(const std::string &symbol_name, Symbol data)
Definition: op.h:1308
LeakyReLUActType
Definition: op.h:3694
Symbol arccos(const std::string &symbol_name, Symbol data)
Definition: op.h:3408
Symbol argmax_channel(const std::string &symbol_name, Symbol data)
Definition: op.h:1076
Symbol batch_take(const std::string &symbol_name, Symbol a, Symbol indices)
Definition: op.h:2809
Symbol LinearRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:6551
Symbol choose_element_0index(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:6758
Symbol degrees(const std::string &symbol_name, Symbol data)
Definition: op.h:3460
Symbol one_hot(const std::string &symbol_name, Symbol indices, int depth, double on_value=1, double off_value=0, One_hotDtype dtype=One_hotDtype::kFloat32)
Definition: op.h:2872
Symbol pick(const std::string &symbol_name, Symbol data, Symbol index, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=0)
Definition: op.h:1128
Symbol negative(const std::string &symbol_name, Symbol data)
Definition: op.h:1426
Symbol GridGenerator(const std::string &symbol_name, Symbol data, GridGeneratorTransformType transform_type, Shape target_shape=Shape(0, 0))
Definition: op.h:5897
Operator interface.
Definition: operator.h:42
Symbol interface.
Definition: symbol.h:71
MakeLossNormalization
Definition: op.h:6692
Symbol log10(const std::string &symbol_name, Symbol data)
Definition: op.h:1870
Symbol log2(const std::string &symbol_name, Symbol data)
Definition: op.h:1891