mxnet
op.h
Go to the documentation of this file.
1 
8 #ifndef MXNET_CPP_OP_H_
9 #define MXNET_CPP_OP_H_
10 
11 #include <string>
12 #include <vector>
13 #include "mxnet-cpp/base.h"
14 #include "mxnet-cpp/shape.h"
15 #include "mxnet-cpp/op_util.h"
16 #include "mxnet-cpp/operator.h"
17 #include "dmlc/optional.h"
18 
19 namespace mxnet {
20 namespace cpp {
21 
51 inline Symbol softmax(const std::string& symbol_name,
52  Symbol data,
53  int axis = -1) {
54  return Operator("softmax")
55  .SetParam("axis", axis)
56  .SetInput("data", data)
57  .CreateSymbol(symbol_name);
58 }
59 
82 inline Symbol log_softmax(const std::string& symbol_name,
83  Symbol data,
84  int axis = -1) {
85  return Operator("log_softmax")
86  .SetParam("axis", axis)
87  .SetInput("data", data)
88  .CreateSymbol(symbol_name);
89 }
90 
94  kInt32 = 0
95 };
96 
135 inline Symbol sample_multinomial(const std::string& symbol_name,
136  Symbol data,
137  Shape shape = Shape(),
138  bool get_prob = false,
140  static const char *Sample_multinomialDtypeValues[] = {
141  "int32"
142  };
143  return Operator("sample_multinomial")
144  .SetParam("shape", shape)
145  .SetParam("get_prob", get_prob)
146  .SetParam("dtype", Sample_multinomialDtypeValues[int(dtype)])
147  .SetInput("data", data)
148  .CreateSymbol(symbol_name);
149 }
150 
154  kNone = 0,
155  kFloat16 = 1,
156  kFloat32 = 2,
157  kFloat64 = 3
158 };
159 
196 inline Symbol sample_uniform(const std::string& symbol_name,
197  Symbol low,
198  Symbol high,
199  Shape shape = Shape(),
201  static const char *Sample_uniformDtypeValues[] = {
202  "None",
203  "float16",
204  "float32",
205  "float64"
206  };
207  return Operator("sample_uniform")
208  .SetParam("shape", shape)
209  .SetParam("dtype", Sample_uniformDtypeValues[int(dtype)])
210  .SetInput("low", low)
211  .SetInput("high", high)
212  .CreateSymbol(symbol_name);
213 }
214 
217 enum class Sample_normalDtype {
218  kNone = 0,
219  kFloat16 = 1,
220  kFloat32 = 2,
221  kFloat64 = 3
222 };
223 
260 inline Symbol sample_normal(const std::string& symbol_name,
261  Symbol mu,
262  Symbol sigma,
263  Shape shape = Shape(),
265  static const char *Sample_normalDtypeValues[] = {
266  "None",
267  "float16",
268  "float32",
269  "float64"
270  };
271  return Operator("sample_normal")
272  .SetParam("shape", shape)
273  .SetParam("dtype", Sample_normalDtypeValues[int(dtype)])
274  .SetInput("mu", mu)
275  .SetInput("sigma", sigma)
276  .CreateSymbol(symbol_name);
277 }
278 
281 enum class Sample_gammaDtype {
282  kNone = 0,
283  kFloat16 = 1,
284  kFloat32 = 2,
285  kFloat64 = 3
286 };
287 
324 inline Symbol sample_gamma(const std::string& symbol_name,
325  Symbol alpha,
326  Symbol beta,
327  Shape shape = Shape(),
329  static const char *Sample_gammaDtypeValues[] = {
330  "None",
331  "float16",
332  "float32",
333  "float64"
334  };
335  return Operator("sample_gamma")
336  .SetParam("shape", shape)
337  .SetParam("dtype", Sample_gammaDtypeValues[int(dtype)])
338  .SetInput("alpha", alpha)
339  .SetInput("beta", beta)
340  .CreateSymbol(symbol_name);
341 }
342 
346  kNone = 0,
347  kFloat16 = 1,
348  kFloat32 = 2,
349  kFloat64 = 3
350 };
351 
386 inline Symbol sample_exponential(const std::string& symbol_name,
387  Symbol lam,
388  Shape shape = Shape(),
390  static const char *Sample_exponentialDtypeValues[] = {
391  "None",
392  "float16",
393  "float32",
394  "float64"
395  };
396  return Operator("sample_exponential")
397  .SetParam("shape", shape)
398  .SetParam("dtype", Sample_exponentialDtypeValues[int(dtype)])
399  .SetInput("lam", lam)
400  .CreateSymbol(symbol_name);
401 }
402 
406  kNone = 0,
407  kFloat16 = 1,
408  kFloat32 = 2,
409  kFloat64 = 3
410 };
411 
448 inline Symbol sample_poisson(const std::string& symbol_name,
449  Symbol lam,
450  Shape shape = Shape(),
452  static const char *Sample_poissonDtypeValues[] = {
453  "None",
454  "float16",
455  "float32",
456  "float64"
457  };
458  return Operator("sample_poisson")
459  .SetParam("shape", shape)
460  .SetParam("dtype", Sample_poissonDtypeValues[int(dtype)])
461  .SetInput("lam", lam)
462  .CreateSymbol(symbol_name);
463 }
464 
468  kNone = 0,
469  kFloat16 = 1,
470  kFloat32 = 2,
471  kFloat64 = 3
472 };
473 
512 inline Symbol sample_negative_binomial(const std::string& symbol_name,
513  Symbol k,
514  Symbol p,
515  Shape shape = Shape(),
517  static const char *Sample_negative_binomialDtypeValues[] = {
518  "None",
519  "float16",
520  "float32",
521  "float64"
522  };
523  return Operator("sample_negative_binomial")
524  .SetParam("shape", shape)
525  .SetParam("dtype", Sample_negative_binomialDtypeValues[int(dtype)])
526  .SetInput("k", k)
527  .SetInput("p", p)
528  .CreateSymbol(symbol_name);
529 }
530 
534  kNone = 0,
535  kFloat16 = 1,
536  kFloat32 = 2,
537  kFloat64 = 3
538 };
539 
578 inline Symbol sample_generalized_negative_binomial(const std::string& symbol_name,
579  Symbol mu,
580  Symbol alpha,
581  Shape shape = Shape(),
583  static const char *Sample_generalized_negative_binomialDtypeValues[] = {
584  "None",
585  "float16",
586  "float32",
587  "float64"
588  };
589  return Operator("sample_generalized_negative_binomial")
590  .SetParam("shape", shape)
591  .SetParam("dtype", Sample_generalized_negative_binomialDtypeValues[int(dtype)])
592  .SetInput("mu", mu)
593  .SetInput("alpha", alpha)
594  .CreateSymbol(symbol_name);
595 }
596 
600  kNone = 0,
601  kFloat16 = 1,
602  kFloat32 = 2,
603  kFloat64 = 3
604 };
605 
630 inline Symbol random_uniform(const std::string& symbol_name,
631  mx_float low = 0,
632  mx_float high = 1,
633  Shape shape = Shape(),
634  const std::string& ctx = "",
636  static const char *Random_uniformDtypeValues[] = {
637  "None",
638  "float16",
639  "float32",
640  "float64"
641  };
642  return Operator("random_uniform")
643  .SetParam("low", low)
644  .SetParam("high", high)
645  .SetParam("shape", shape)
646  .SetParam("dtype", Random_uniformDtypeValues[int(dtype)])
647  .CreateSymbol(symbol_name);
648 }
649 
652 enum class Random_normalDtype {
653  kNone = 0,
654  kFloat16 = 1,
655  kFloat32 = 2,
656  kFloat64 = 3
657 };
658 
681 inline Symbol random_normal(const std::string& symbol_name,
682  mx_float loc = 0,
683  mx_float scale = 1,
684  Shape shape = Shape(),
685  const std::string& ctx = "",
687  static const char *Random_normalDtypeValues[] = {
688  "None",
689  "float16",
690  "float32",
691  "float64"
692  };
693  return Operator("random_normal")
694  .SetParam("loc", loc)
695  .SetParam("scale", scale)
696  .SetParam("shape", shape)
697  .SetParam("dtype", Random_normalDtypeValues[int(dtype)])
698  .CreateSymbol(symbol_name);
699 }
700 
703 enum class Random_gammaDtype {
704  kNone = 0,
705  kFloat16 = 1,
706  kFloat32 = 2,
707  kFloat64 = 3
708 };
709 
730 inline Symbol random_gamma(const std::string& symbol_name,
731  mx_float alpha = 1,
732  mx_float beta = 1,
733  Shape shape = Shape(),
734  const std::string& ctx = "",
736  static const char *Random_gammaDtypeValues[] = {
737  "None",
738  "float16",
739  "float32",
740  "float64"
741  };
742  return Operator("random_gamma")
743  .SetParam("alpha", alpha)
744  .SetParam("beta", beta)
745  .SetParam("shape", shape)
746  .SetParam("dtype", Random_gammaDtypeValues[int(dtype)])
747  .CreateSymbol(symbol_name);
748 }
749 
753  kNone = 0,
754  kFloat16 = 1,
755  kFloat32 = 2,
756  kFloat64 = 3
757 };
758 
778 inline Symbol random_exponential(const std::string& symbol_name,
779  mx_float lam = 1,
780  Shape shape = Shape(),
781  const std::string& ctx = "",
783  static const char *Random_exponentialDtypeValues[] = {
784  "None",
785  "float16",
786  "float32",
787  "float64"
788  };
789  return Operator("random_exponential")
790  .SetParam("lam", lam)
791  .SetParam("shape", shape)
792  .SetParam("dtype", Random_exponentialDtypeValues[int(dtype)])
793  .CreateSymbol(symbol_name);
794 }
795 
799  kNone = 0,
800  kFloat16 = 1,
801  kFloat32 = 2,
802  kFloat64 = 3
803 };
804 
825 inline Symbol random_poisson(const std::string& symbol_name,
826  mx_float lam = 1,
827  Shape shape = Shape(),
828  const std::string& ctx = "",
830  static const char *Random_poissonDtypeValues[] = {
831  "None",
832  "float16",
833  "float32",
834  "float64"
835  };
836  return Operator("random_poisson")
837  .SetParam("lam", lam)
838  .SetParam("shape", shape)
839  .SetParam("dtype", Random_poissonDtypeValues[int(dtype)])
840  .CreateSymbol(symbol_name);
841 }
842 
846  kNone = 0,
847  kFloat16 = 1,
848  kFloat32 = 2,
849  kFloat64 = 3
850 };
851 
874 inline Symbol random_negative_binomial(const std::string& symbol_name,
875  int k = 1,
876  mx_float p = 1,
877  Shape shape = Shape(),
878  const std::string& ctx = "",
880  static const char *Random_negative_binomialDtypeValues[] = {
881  "None",
882  "float16",
883  "float32",
884  "float64"
885  };
886  return Operator("random_negative_binomial")
887  .SetParam("k", k)
888  .SetParam("p", p)
889  .SetParam("shape", shape)
890  .SetParam("dtype", Random_negative_binomialDtypeValues[int(dtype)])
891  .CreateSymbol(symbol_name);
892 }
893 
897  kNone = 0,
898  kFloat16 = 1,
899  kFloat32 = 2,
900  kFloat64 = 3
901 };
902 
926 inline Symbol random_generalized_negative_binomial(const std::string& symbol_name,
927  mx_float mu = 1,
928  mx_float alpha = 1,
929  Shape shape = Shape(),
930  const std::string& ctx = "",
932  static const char *Random_generalized_negative_binomialDtypeValues[] = {
933  "None",
934  "float16",
935  "float32",
936  "float64"
937  };
938  return Operator("random_generalized_negative_binomial")
939  .SetParam("mu", mu)
940  .SetParam("alpha", alpha)
941  .SetParam("shape", shape)
942  .SetParam("dtype", Random_generalized_negative_binomialDtypeValues[int(dtype)])
943  .CreateSymbol(symbol_name);
944 }
945 
968 inline Symbol broadcast_power(const std::string& symbol_name,
969  Symbol lhs,
970  Symbol rhs) {
971  return Operator("broadcast_power")
972  .SetInput("lhs", lhs)
973  .SetInput("rhs", rhs)
974  .CreateSymbol(symbol_name);
975 }
976 
1001 inline Symbol broadcast_maximum(const std::string& symbol_name,
1002  Symbol lhs,
1003  Symbol rhs) {
1004  return Operator("broadcast_maximum")
1005  .SetInput("lhs", lhs)
1006  .SetInput("rhs", rhs)
1007  .CreateSymbol(symbol_name);
1008 }
1009 
1034 inline Symbol broadcast_minimum(const std::string& symbol_name,
1035  Symbol lhs,
1036  Symbol rhs) {
1037  return Operator("broadcast_minimum")
1038  .SetInput("lhs", lhs)
1039  .SetInput("rhs", rhs)
1040  .CreateSymbol(symbol_name);
1041 }
1042 
1073 inline Symbol broadcast_hypot(const std::string& symbol_name,
1074  Symbol lhs,
1075  Symbol rhs) {
1076  return Operator("broadcast_hypot")
1077  .SetInput("lhs", lhs)
1078  .SetInput("rhs", rhs)
1079  .CreateSymbol(symbol_name);
1080 }
1081 
1156 inline Symbol Reshape(const std::string& symbol_name,
1157  Symbol data,
1158  Shape shape = Shape(),
1159  bool reverse = false,
1160  Shape target_shape = Shape(),
1161  bool keep_highest = false) {
1162  return Operator("Reshape")
1163  .SetParam("shape", shape)
1164  .SetParam("reverse", reverse)
1165  .SetParam("target_shape", target_shape)
1166  .SetParam("keep_highest", keep_highest)
1167  .SetInput("data", data)
1168  .CreateSymbol(symbol_name);
1169 }
1170 
1201 inline Symbol Flatten(const std::string& symbol_name,
1202  Symbol data) {
1203  return Operator("Flatten")
1204  .SetInput("data", data)
1205  .CreateSymbol(symbol_name);
1206 }
1207 
1244 inline Symbol transpose(const std::string& symbol_name,
1245  Symbol data,
1246  Shape axes = Shape()) {
1247  return Operator("transpose")
1248  .SetParam("axes", axes)
1249  .SetInput("data", data)
1250  .CreateSymbol(symbol_name);
1251 }
1252 
1268 inline Symbol expand_dims(const std::string& symbol_name,
1269  Symbol data,
1270  int axis) {
1271  return Operator("expand_dims")
1272  .SetParam("axis", axis)
1273  .SetInput("data", data)
1274  .CreateSymbol(symbol_name);
1275 }
1276 
1310 inline Symbol slice(const std::string& symbol_name,
1311  Symbol data,
1312  Shape begin,
1313  Shape end) {
1314  return Operator("slice")
1315  .SetParam("begin", begin)
1316  .SetParam("end", end)
1317  .SetInput("data", data)
1318  .CreateSymbol(symbol_name);
1319 }
1320 
1353 inline Symbol slice_axis(const std::string& symbol_name,
1354  Symbol data,
1355  int axis,
1356  int begin,
1357  dmlc::optional<int> end) {
1358  return Operator("slice_axis")
1359  .SetParam("axis", axis)
1360  .SetParam("begin", begin)
1361  .SetParam("end", end)
1362  .SetInput("data", data)
1363  .CreateSymbol(symbol_name);
1364 }
1365 
1397 inline Symbol dot(const std::string& symbol_name,
1398  Symbol lhs,
1399  Symbol rhs,
1400  bool transpose_a = false,
1401  bool transpose_b = false) {
1402  return Operator("dot")
1403  .SetParam("transpose_a", transpose_a)
1404  .SetParam("transpose_b", transpose_b)
1405  .SetInput("lhs", lhs)
1406  .SetInput("rhs", rhs)
1407  .CreateSymbol(symbol_name);
1408 }
1409 
1432 inline Symbol batch_dot(const std::string& symbol_name,
1433  Symbol lhs,
1434  Symbol rhs,
1435  bool transpose_a = false,
1436  bool transpose_b = false) {
1437  return Operator("batch_dot")
1438  .SetParam("transpose_a", transpose_a)
1439  .SetParam("transpose_b", transpose_b)
1440  .SetInput("lhs", lhs)
1441  .SetInput("rhs", rhs)
1442  .CreateSymbol(symbol_name);
1443 }
1444 
1468 inline Symbol clip(const std::string& symbol_name,
1469  Symbol data,
1470  mx_float a_min,
1471  mx_float a_max) {
1472  return Operator("clip")
1473  .SetParam("a_min", a_min)
1474  .SetParam("a_max", a_max)
1475  .SetInput("data", data)
1476  .CreateSymbol(symbol_name);
1477 }
1478 
1513 inline Symbol repeat(const std::string& symbol_name,
1514  Symbol data,
1515  int repeats,
1516  dmlc::optional<int> axis = dmlc::optional<int>()) {
1517  return Operator("repeat")
1518  .SetParam("repeats", repeats)
1519  .SetParam("axis", axis)
1520  .SetInput("data", data)
1521  .CreateSymbol(symbol_name);
1522 }
1523 
1569 inline Symbol tile(const std::string& symbol_name,
1570  Symbol data,
1571  Shape reps) {
1572  return Operator("tile")
1573  .SetParam("reps", reps)
1574  .SetInput("data", data)
1575  .CreateSymbol(symbol_name);
1576 }
1577 
1601 inline Symbol reverse(const std::string& symbol_name,
1602  Symbol data,
1603  Shape axis) {
1604  return Operator("reverse")
1605  .SetParam("axis", axis)
1606  .SetInput("data", data)
1607  .CreateSymbol(symbol_name);
1608 }
1609 
1633 inline Symbol stack(const std::string& symbol_name,
1634  const std::vector<Symbol>& data,
1635  int num_args,
1636  int axis = 0) {
1637  return Operator("stack")
1638  .SetParam("num_args", num_args)
1639  .SetParam("axis", axis)
1640 (data)
1641  .CreateSymbol(symbol_name);
1642 }
1643 
1661 inline Symbol zeros_like(const std::string& symbol_name,
1662  Symbol data) {
1663  return Operator("zeros_like")
1664  .SetInput("data", data)
1665  .CreateSymbol(symbol_name);
1666 }
1667 
1685 inline Symbol ones_like(const std::string& symbol_name,
1686  Symbol data) {
1687  return Operator("ones_like")
1688  .SetInput("data", data)
1689  .CreateSymbol(symbol_name);
1690 }
1691 
1719 inline Symbol broadcast_add(const std::string& symbol_name,
1720  Symbol lhs,
1721  Symbol rhs) {
1722  return Operator("broadcast_add")
1723  .SetInput("lhs", lhs)
1724  .SetInput("rhs", rhs)
1725  .CreateSymbol(symbol_name);
1726 }
1727 
1755 inline Symbol broadcast_sub(const std::string& symbol_name,
1756  Symbol lhs,
1757  Symbol rhs) {
1758  return Operator("broadcast_sub")
1759  .SetInput("lhs", lhs)
1760  .SetInput("rhs", rhs)
1761  .CreateSymbol(symbol_name);
1762 }
1763 
1786 inline Symbol broadcast_mul(const std::string& symbol_name,
1787  Symbol lhs,
1788  Symbol rhs) {
1789  return Operator("broadcast_mul")
1790  .SetInput("lhs", lhs)
1791  .SetInput("rhs", rhs)
1792  .CreateSymbol(symbol_name);
1793 }
1794 
1817 inline Symbol broadcast_div(const std::string& symbol_name,
1818  Symbol lhs,
1819  Symbol rhs) {
1820  return Operator("broadcast_div")
1821  .SetInput("lhs", lhs)
1822  .SetInput("rhs", rhs)
1823  .CreateSymbol(symbol_name);
1824 }
1825 
1848 inline Symbol broadcast_mod(const std::string& symbol_name,
1849  Symbol lhs,
1850  Symbol rhs) {
1851  return Operator("broadcast_mod")
1852  .SetInput("lhs", lhs)
1853  .SetInput("rhs", rhs)
1854  .CreateSymbol(symbol_name);
1855 }
1856 
1871 inline Symbol add_n(const std::string& symbol_name,
1872  const std::vector<Symbol>& args) {
1873  return Operator("add_n")
1874 (args)
1875  .CreateSymbol(symbol_name);
1876 }
1877 
1909 inline Symbol argmax(const std::string& symbol_name,
1910  Symbol data,
1911  dmlc::optional<int> axis = dmlc::optional<int>(),
1912  bool keepdims = false) {
1913  return Operator("argmax")
1914  .SetParam("axis", axis)
1915  .SetParam("keepdims", keepdims)
1916  .SetInput("data", data)
1917  .CreateSymbol(symbol_name);
1918 }
1919 
1951 inline Symbol argmin(const std::string& symbol_name,
1952  Symbol data,
1953  dmlc::optional<int> axis = dmlc::optional<int>(),
1954  bool keepdims = false) {
1955  return Operator("argmin")
1956  .SetParam("axis", axis)
1957  .SetParam("keepdims", keepdims)
1958  .SetInput("data", data)
1959  .CreateSymbol(symbol_name);
1960 }
1961 
1984 inline Symbol argmax_channel(const std::string& symbol_name,
1985  Symbol data) {
1986  return Operator("argmax_channel")
1987  .SetInput("data", data)
1988  .CreateSymbol(symbol_name);
1989 }
1990 
2036 inline Symbol pick(const std::string& symbol_name,
2037  Symbol data,
2038  Symbol index,
2039  dmlc::optional<int> axis = dmlc::optional<int>(),
2040  bool keepdims = false) {
2041  return Operator("pick")
2042  .SetParam("axis", axis)
2043  .SetParam("keepdims", keepdims)
2044  .SetInput("data", data)
2045  .SetInput("index", index)
2046  .CreateSymbol(symbol_name);
2047 }
2048 
2093 inline Symbol sum(const std::string& symbol_name,
2094  Symbol data,
2095  Shape axis = Shape(),
2096  bool keepdims = false,
2097  bool exclude = false) {
2098  return Operator("sum")
2099  .SetParam("axis", axis)
2100  .SetParam("keepdims", keepdims)
2101  .SetParam("exclude", exclude)
2102  .SetInput("data", data)
2103  .CreateSymbol(symbol_name);
2104 }
2105 
2130 inline Symbol mean(const std::string& symbol_name,
2131  Symbol data,
2132  Shape axis = Shape(),
2133  bool keepdims = false,
2134  bool exclude = false) {
2135  return Operator("mean")
2136  .SetParam("axis", axis)
2137  .SetParam("keepdims", keepdims)
2138  .SetParam("exclude", exclude)
2139  .SetInput("data", data)
2140  .CreateSymbol(symbol_name);
2141 }
2142 
2167 inline Symbol prod(const std::string& symbol_name,
2168  Symbol data,
2169  Shape axis = Shape(),
2170  bool keepdims = false,
2171  bool exclude = false) {
2172  return Operator("prod")
2173  .SetParam("axis", axis)
2174  .SetParam("keepdims", keepdims)
2175  .SetParam("exclude", exclude)
2176  .SetInput("data", data)
2177  .CreateSymbol(symbol_name);
2178 }
2179 
2206 inline Symbol nansum(const std::string& symbol_name,
2207  Symbol data,
2208  Shape axis = Shape(),
2209  bool keepdims = false,
2210  bool exclude = false) {
2211  return Operator("nansum")
2212  .SetParam("axis", axis)
2213  .SetParam("keepdims", keepdims)
2214  .SetParam("exclude", exclude)
2215  .SetInput("data", data)
2216  .CreateSymbol(symbol_name);
2217 }
2218 
2245 inline Symbol nanprod(const std::string& symbol_name,
2246  Symbol data,
2247  Shape axis = Shape(),
2248  bool keepdims = false,
2249  bool exclude = false) {
2250  return Operator("nanprod")
2251  .SetParam("axis", axis)
2252  .SetParam("keepdims", keepdims)
2253  .SetParam("exclude", exclude)
2254  .SetInput("data", data)
2255  .CreateSymbol(symbol_name);
2256 }
2257 
2282 inline Symbol max(const std::string& symbol_name,
2283  Symbol data,
2284  Shape axis = Shape(),
2285  bool keepdims = false,
2286  bool exclude = false) {
2287  return Operator("max")
2288  .SetParam("axis", axis)
2289  .SetParam("keepdims", keepdims)
2290  .SetParam("exclude", exclude)
2291  .SetInput("data", data)
2292  .CreateSymbol(symbol_name);
2293 }
2294 
2319 inline Symbol min(const std::string& symbol_name,
2320  Symbol data,
2321  Shape axis = Shape(),
2322  bool keepdims = false,
2323  bool exclude = false) {
2324  return Operator("min")
2325  .SetParam("axis", axis)
2326  .SetParam("keepdims", keepdims)
2327  .SetParam("exclude", exclude)
2328  .SetInput("data", data)
2329  .CreateSymbol(symbol_name);
2330 }
2331 
2361 inline Symbol broadcast_axis(const std::string& symbol_name,
2362  Symbol data,
2363  Shape axis = Shape(),
2364  Shape size = Shape()) {
2365  return Operator("broadcast_axis")
2366  .SetParam("axis", axis)
2367  .SetParam("size", size)
2368  .SetInput("data", data)
2369  .CreateSymbol(symbol_name);
2370 }
2371 
2400 inline Symbol broadcast_to(const std::string& symbol_name,
2401  Symbol data,
2402  Shape shape = Shape()) {
2403  return Operator("broadcast_to")
2404  .SetParam("shape", shape)
2405  .SetInput("data", data)
2406  .CreateSymbol(symbol_name);
2407 }
2408 
2426 inline Symbol norm(const std::string& symbol_name,
2427  Symbol data) {
2428  return Operator("norm")
2429  .SetInput("data", data)
2430  .CreateSymbol(symbol_name);
2431 }
2432 
2438 enum class TopkRetTyp {
2439  kBoth = 0,
2440  kIndices = 1,
2441  kMask = 2,
2442  kValue = 3
2443 };
2444 
2486 inline Symbol topk(const std::string& symbol_name,
2487  Symbol data,
2488  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2489  int k = 1,
2490  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
2491  bool is_ascend = false) {
2492  static const char *TopkRetTypValues[] = {
2493  "both",
2494  "indices",
2495  "mask",
2496  "value"
2497  };
2498  return Operator("topk")
2499  .SetParam("axis", axis)
2500  .SetParam("k", k)
2501  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
2502  .SetParam("is_ascend", is_ascend)
2503  .SetInput("data", data)
2504  .CreateSymbol(symbol_name);
2505 }
2506 
2539 inline Symbol sort(const std::string& symbol_name,
2540  Symbol data,
2541  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2542  bool is_ascend = true) {
2543  return Operator("sort")
2544  .SetParam("axis", axis)
2545  .SetParam("is_ascend", is_ascend)
2546  .SetInput("data", data)
2547  .CreateSymbol(symbol_name);
2548 }
2549 
2580 inline Symbol argsort(const std::string& symbol_name,
2581  Symbol data,
2582  dmlc::optional<int> axis = dmlc::optional<int>(-1),
2583  bool is_ascend = true) {
2584  return Operator("argsort")
2585  .SetParam("axis", axis)
2586  .SetParam("is_ascend", is_ascend)
2587  .SetInput("data", data)
2588  .CreateSymbol(symbol_name);
2589 }
2590 
2598 inline Symbol elemwise_add(const std::string& symbol_name,
2599  Symbol lhs,
2600  Symbol rhs) {
2601  return Operator("elemwise_add")
2602  .SetInput("lhs", lhs)
2603  .SetInput("rhs", rhs)
2604  .CreateSymbol(symbol_name);
2605 }
2606 
2620 inline Symbol relu(const std::string& symbol_name,
2621  Symbol data) {
2622  return Operator("relu")
2623  .SetInput("data", data)
2624  .CreateSymbol(symbol_name);
2625 }
2626 
2640 inline Symbol sigmoid(const std::string& symbol_name,
2641  Symbol data) {
2642  return Operator("sigmoid")
2643  .SetInput("data", data)
2644  .CreateSymbol(symbol_name);
2645 }
2646 
2680 inline Symbol BlockGrad(const std::string& symbol_name,
2681  Symbol data) {
2682  return Operator("BlockGrad")
2683  .SetInput("data", data)
2684  .CreateSymbol(symbol_name);
2685 }
2686 
2697 inline Symbol make_loss(const std::string& symbol_name,
2698  Symbol data) {
2699  return Operator("make_loss")
2700  .SetInput("data", data)
2701  .CreateSymbol(symbol_name);
2702 }
2703 
2706 enum class CastDtype {
2707  kFloat16 = 0,
2708  kFloat32 = 1,
2709  kFloat64 = 2,
2710  kInt32 = 3,
2711  kUint8 = 4
2712 };
2713 
2733 inline Symbol Cast(const std::string& symbol_name,
2734  Symbol data,
2735  CastDtype dtype) {
2736  static const char *CastDtypeValues[] = {
2737  "float16",
2738  "float32",
2739  "float64",
2740  "int32",
2741  "uint8"
2742  };
2743  return Operator("Cast")
2744  .SetParam("dtype", CastDtypeValues[int(dtype)])
2745  .SetInput("data", data)
2746  .CreateSymbol(symbol_name);
2747 }
2748 
2757 inline Symbol negative(const std::string& symbol_name,
2758  Symbol data) {
2759  return Operator("negative")
2760  .SetInput("data", data)
2761  .CreateSymbol(symbol_name);
2762 }
2763 
2780 inline Symbol reciprocal(const std::string& symbol_name,
2781  Symbol data) {
2782  return Operator("reciprocal")
2783  .SetInput("data", data)
2784  .CreateSymbol(symbol_name);
2785 }
2786 
2801 inline Symbol abs(const std::string& symbol_name,
2802  Symbol data) {
2803  return Operator("abs")
2804  .SetInput("data", data)
2805  .CreateSymbol(symbol_name);
2806 }
2807 
2822 inline Symbol sign(const std::string& symbol_name,
2823  Symbol data) {
2824  return Operator("sign")
2825  .SetInput("data", data)
2826  .CreateSymbol(symbol_name);
2827 }
2828 
2843 inline Symbol round(const std::string& symbol_name,
2844  Symbol data) {
2845  return Operator("round")
2846  .SetInput("data", data)
2847  .CreateSymbol(symbol_name);
2848 }
2849 
2868 inline Symbol rint(const std::string& symbol_name,
2869  Symbol data) {
2870  return Operator("rint")
2871  .SetInput("data", data)
2872  .CreateSymbol(symbol_name);
2873 }
2874 
2891 inline Symbol ceil(const std::string& symbol_name,
2892  Symbol data) {
2893  return Operator("ceil")
2894  .SetInput("data", data)
2895  .CreateSymbol(symbol_name);
2896 }
2897 
2914 inline Symbol floor(const std::string& symbol_name,
2915  Symbol data) {
2916  return Operator("floor")
2917  .SetInput("data", data)
2918  .CreateSymbol(symbol_name);
2919 }
2920 
2938 inline Symbol trunc(const std::string& symbol_name,
2939  Symbol data) {
2940  return Operator("trunc")
2941  .SetInput("data", data)
2942  .CreateSymbol(symbol_name);
2943 }
2944 
2959 inline Symbol fix(const std::string& symbol_name,
2960  Symbol data) {
2961  return Operator("fix")
2962  .SetInput("data", data)
2963  .CreateSymbol(symbol_name);
2964 }
2965 
2983 inline Symbol square(const std::string& symbol_name,
2984  Symbol data) {
2985  return Operator("square")
2986  .SetInput("data", data)
2987  .CreateSymbol(symbol_name);
2988 }
2989 
3007 inline Symbol sqrt(const std::string& symbol_name,
3008  Symbol data) {
3009  return Operator("sqrt")
3010  .SetInput("data", data)
3011  .CreateSymbol(symbol_name);
3012 }
3013 
3031 inline Symbol rsqrt(const std::string& symbol_name,
3032  Symbol data) {
3033  return Operator("rsqrt")
3034  .SetInput("data", data)
3035  .CreateSymbol(symbol_name);
3036 }
3037 
3055 inline Symbol exp(const std::string& symbol_name,
3056  Symbol data) {
3057  return Operator("exp")
3058  .SetInput("data", data)
3059  .CreateSymbol(symbol_name);
3060 }
3061 
3074 inline Symbol log(const std::string& symbol_name,
3075  Symbol data) {
3076  return Operator("log")
3077  .SetInput("data", data)
3078  .CreateSymbol(symbol_name);
3079 }
3080 
3093 inline Symbol log10(const std::string& symbol_name,
3094  Symbol data) {
3095  return Operator("log10")
3096  .SetInput("data", data)
3097  .CreateSymbol(symbol_name);
3098 }
3099 
3112 inline Symbol log2(const std::string& symbol_name,
3113  Symbol data) {
3114  return Operator("log2")
3115  .SetInput("data", data)
3116  .CreateSymbol(symbol_name);
3117 }
3118 
3134 inline Symbol sin(const std::string& symbol_name,
3135  Symbol data) {
3136  return Operator("sin")
3137  .SetInput("data", data)
3138  .CreateSymbol(symbol_name);
3139 }
3140 
3154 inline Symbol log1p(const std::string& symbol_name,
3155  Symbol data) {
3156  return Operator("log1p")
3157  .SetInput("data", data)
3158  .CreateSymbol(symbol_name);
3159 }
3160 
3173 inline Symbol expm1(const std::string& symbol_name,
3174  Symbol data) {
3175  return Operator("expm1")
3176  .SetInput("data", data)
3177  .CreateSymbol(symbol_name);
3178 }
3179 
3195 inline Symbol cos(const std::string& symbol_name,
3196  Symbol data) {
3197  return Operator("cos")
3198  .SetInput("data", data)
3199  .CreateSymbol(symbol_name);
3200 }
3201 
3217 inline Symbol tan(const std::string& symbol_name,
3218  Symbol data) {
3219  return Operator("tan")
3220  .SetInput("data", data)
3221  .CreateSymbol(symbol_name);
3222 }
3223 
3240 inline Symbol arcsin(const std::string& symbol_name,
3241  Symbol data) {
3242  return Operator("arcsin")
3243  .SetInput("data", data)
3244  .CreateSymbol(symbol_name);
3245 }
3246 
3263 inline Symbol arccos(const std::string& symbol_name,
3264  Symbol data) {
3265  return Operator("arccos")
3266  .SetInput("data", data)
3267  .CreateSymbol(symbol_name);
3268 }
3269 
3285 inline Symbol arctan(const std::string& symbol_name,
3286  Symbol data) {
3287  return Operator("arctan")
3288  .SetInput("data", data)
3289  .CreateSymbol(symbol_name);
3290 }
3291 
3305 inline Symbol degrees(const std::string& symbol_name,
3306  Symbol data) {
3307  return Operator("degrees")
3308  .SetInput("data", data)
3309  .CreateSymbol(symbol_name);
3310 }
3311 
3325 inline Symbol radians(const std::string& symbol_name,
3326  Symbol data) {
3327  return Operator("radians")
3328  .SetInput("data", data)
3329  .CreateSymbol(symbol_name);
3330 }
3331 
3345 inline Symbol sinh(const std::string& symbol_name,
3346  Symbol data) {
3347  return Operator("sinh")
3348  .SetInput("data", data)
3349  .CreateSymbol(symbol_name);
3350 }
3351 
3365 inline Symbol cosh(const std::string& symbol_name,
3366  Symbol data) {
3367  return Operator("cosh")
3368  .SetInput("data", data)
3369  .CreateSymbol(symbol_name);
3370 }
3371 
3385 inline Symbol tanh(const std::string& symbol_name,
3386  Symbol data) {
3387  return Operator("tanh")
3388  .SetInput("data", data)
3389  .CreateSymbol(symbol_name);
3390 }
3391 
3401 inline Symbol arcsinh(const std::string& symbol_name,
3402  Symbol data) {
3403  return Operator("arcsinh")
3404  .SetInput("data", data)
3405  .CreateSymbol(symbol_name);
3406 }
3407 
3417 inline Symbol arccosh(const std::string& symbol_name,
3418  Symbol data) {
3419  return Operator("arccosh")
3420  .SetInput("data", data)
3421  .CreateSymbol(symbol_name);
3422 }
3423 
3433 inline Symbol arctanh(const std::string& symbol_name,
3434  Symbol data) {
3435  return Operator("arctanh")
3436  .SetInput("data", data)
3437  .CreateSymbol(symbol_name);
3438 }
3439 
3448 inline Symbol gamma(const std::string& symbol_name,
3449  Symbol data) {
3450  return Operator("gamma")
3451  .SetInput("data", data)
3452  .CreateSymbol(symbol_name);
3453 }
3454 
3463 inline Symbol gammaln(const std::string& symbol_name,
3464  Symbol data) {
3465  return Operator("gammaln")
3466  .SetInput("data", data)
3467  .CreateSymbol(symbol_name);
3468 }
3469 
3472 enum class EmbeddingDtype {
3473  kFloat16 = 0,
3474  kFloat32 = 1,
3475  kFloat64 = 2,
3476  kInt32 = 3,
3477  kUint8 = 4
3478 };
3479 
3531 inline Symbol Embedding(const std::string& symbol_name,
3532  Symbol data,
3533  Symbol weight,
3534  int input_dim,
3535  int output_dim,
3537  static const char *EmbeddingDtypeValues[] = {
3538  "float16",
3539  "float32",
3540  "float64",
3541  "int32",
3542  "uint8"
3543  };
3544  return Operator("Embedding")
3545  .SetParam("input_dim", input_dim)
3546  .SetParam("output_dim", output_dim)
3547  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
3548  .SetInput("data", data)
3549  .SetInput("weight", weight)
3550  .CreateSymbol(symbol_name);
3551 }
3552 
3557 enum class TakeMode {
3558  kClip = 0,
3559  kRaise = 1,
3560  kWrap = 2
3561 };
3562 
3602 inline Symbol take(const std::string& symbol_name,
3603  Symbol a,
3604  Symbol indices,
3605  int axis = 0,
3606  TakeMode mode = TakeMode::kClip) {
3607  static const char *TakeModeValues[] = {
3608  "clip",
3609  "raise",
3610  "wrap"
3611  };
3612  return Operator("take")
3613  .SetParam("axis", axis)
3614  .SetParam("mode", TakeModeValues[int(mode)])
3615  .SetInput("a", a)
3616  .SetInput("indices", indices)
3617  .CreateSymbol(symbol_name);
3618 }
3619 
3648 inline Symbol batch_take(const std::string& symbol_name,
3649  Symbol a,
3650  Symbol indices) {
3651  return Operator("batch_take")
3652  .SetInput("a", a)
3653  .SetInput("indices", indices)
3654  .CreateSymbol(symbol_name);
3655 }
3656 
3659 enum class One_hotDtype {
3660  kFloat16 = 0,
3661  kFloat32 = 1,
3662  kFloat64 = 2,
3663  kInt32 = 3,
3664  kUint8 = 4
3665 };
3666 
3711 inline Symbol one_hot(const std::string& symbol_name,
3712  Symbol indices,
3713  int depth,
3714  double on_value = 1,
3715  double off_value = 0,
3717  static const char *One_hotDtypeValues[] = {
3718  "float16",
3719  "float32",
3720  "float64",
3721  "int32",
3722  "uint8"
3723  };
3724  return Operator("one_hot")
3725  .SetParam("depth", depth)
3726  .SetParam("on_value", on_value)
3727  .SetParam("off_value", off_value)
3728  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
3729  .SetInput("indices", indices)
3730  .CreateSymbol(symbol_name);
3731 }
3732 
3755 inline Symbol broadcast_equal(const std::string& symbol_name,
3756  Symbol lhs,
3757  Symbol rhs) {
3758  return Operator("broadcast_equal")
3759  .SetInput("lhs", lhs)
3760  .SetInput("rhs", rhs)
3761  .CreateSymbol(symbol_name);
3762 }
3763 
3786 inline Symbol broadcast_not_equal(const std::string& symbol_name,
3787  Symbol lhs,
3788  Symbol rhs) {
3789  return Operator("broadcast_not_equal")
3790  .SetInput("lhs", lhs)
3791  .SetInput("rhs", rhs)
3792  .CreateSymbol(symbol_name);
3793 }
3794 
3817 inline Symbol broadcast_greater(const std::string& symbol_name,
3818  Symbol lhs,
3819  Symbol rhs) {
3820  return Operator("broadcast_greater")
3821  .SetInput("lhs", lhs)
3822  .SetInput("rhs", rhs)
3823  .CreateSymbol(symbol_name);
3824 }
3825 
3848 inline Symbol broadcast_greater_equal(const std::string& symbol_name,
3849  Symbol lhs,
3850  Symbol rhs) {
3851  return Operator("broadcast_greater_equal")
3852  .SetInput("lhs", lhs)
3853  .SetInput("rhs", rhs)
3854  .CreateSymbol(symbol_name);
3855 }
3856 
3879 inline Symbol broadcast_lesser(const std::string& symbol_name,
3880  Symbol lhs,
3881  Symbol rhs) {
3882  return Operator("broadcast_lesser")
3883  .SetInput("lhs", lhs)
3884  .SetInput("rhs", rhs)
3885  .CreateSymbol(symbol_name);
3886 }
3887 
3910 inline Symbol broadcast_lesser_equal(const std::string& symbol_name,
3911  Symbol lhs,
3912  Symbol rhs) {
3913  return Operator("broadcast_lesser_equal")
3914  .SetInput("lhs", lhs)
3915  .SetInput("rhs", rhs)
3916  .CreateSymbol(symbol_name);
3917 }
3918 
3964 inline Symbol linalg_gemm(const std::string& symbol_name,
3965  Symbol A,
3966  Symbol B,
3967  Symbol C,
3968  bool transpose_a = false,
3969  bool transpose_b = false,
3970  double alpha = 1,
3971  double beta = 1) {
3972  return Operator("linalg_gemm")
3973  .SetParam("transpose_a", transpose_a)
3974  .SetParam("transpose_b", transpose_b)
3975  .SetParam("alpha", alpha)
3976  .SetParam("beta", beta)
3977  .SetInput("A", A)
3978  .SetInput("B", B)
3979  .SetInput("C", C)
3980  .CreateSymbol(symbol_name);
3981 }
3982 
4024 inline Symbol linalg_gemm2(const std::string& symbol_name,
4025  Symbol A,
4026  Symbol B,
4027  bool transpose_a = false,
4028  bool transpose_b = false,
4029  double alpha = 1) {
4030  return Operator("linalg_gemm2")
4031  .SetParam("transpose_a", transpose_a)
4032  .SetParam("transpose_b", transpose_b)
4033  .SetParam("alpha", alpha)
4034  .SetInput("A", A)
4035  .SetInput("B", B)
4036  .CreateSymbol(symbol_name);
4037 }
4038 
4074 inline Symbol linalg_potrf(const std::string& symbol_name,
4075  Symbol A) {
4076  return Operator("linalg_potrf")
4077  .SetInput("A", A)
4078  .CreateSymbol(symbol_name);
4079 }
4080 
4116 inline Symbol linalg_potri(const std::string& symbol_name,
4117  Symbol A) {
4118  return Operator("linalg_potri")
4119  .SetInput("A", A)
4120  .CreateSymbol(symbol_name);
4121 }
4122 
4170 inline Symbol linalg_trmm(const std::string& symbol_name,
4171  Symbol A,
4172  Symbol B,
4173  bool transpose = false,
4174  bool rightside = false,
4175  double alpha = 1) {
4176  return Operator("linalg_trmm")
4177  .SetParam("transpose", transpose)
4178  .SetParam("rightside", rightside)
4179  .SetParam("alpha", alpha)
4180  .SetInput("A", A)
4181  .SetInput("B", B)
4182  .CreateSymbol(symbol_name);
4183 }
4184 
4232 inline Symbol linalg_trsm(const std::string& symbol_name,
4233  Symbol A,
4234  Symbol B,
4235  bool transpose = false,
4236  bool rightside = false,
4237  double alpha = 1) {
4238  return Operator("linalg_trsm")
4239  .SetParam("transpose", transpose)
4240  .SetParam("rightside", rightside)
4241  .SetParam("alpha", alpha)
4242  .SetInput("A", A)
4243  .SetInput("B", B)
4244  .CreateSymbol(symbol_name);
4245 }
4246 
4275 inline Symbol linalg_sumlogdiag(const std::string& symbol_name,
4276  Symbol A) {
4277  return Operator("linalg_sumlogdiag")
4278  .SetInput("A", A)
4279  .CreateSymbol(symbol_name);
4280 }
4281 
4298 inline Symbol where(const std::string& symbol_name,
4299  Symbol condition,
4300  Symbol x,
4301  Symbol y) {
4302  return Operator("where")
4303  .SetInput("condition", condition)
4304  .SetInput("x", x)
4305  .SetInput("y", y)
4306  .CreateSymbol(symbol_name);
4307 }
4308 
4334 inline Symbol smooth_l1(const std::string& symbol_name,
4335  Symbol data,
4336  mx_float scalar) {
4337  return Operator("smooth_l1")
4338  .SetParam("scalar", scalar)
4339  .SetInput("data", data)
4340  .CreateSymbol(symbol_name);
4341 }
4342 
4358 inline Symbol Custom(const std::string& symbol_name,
4359  const std::vector<Symbol>& data,
4360  const std::string& op_type) {
4361  return Operator("Custom")
4362 (data)
4363  .CreateSymbol(symbol_name);
4364 }
4365 
4394 inline Symbol SwapAxis(const std::string& symbol_name,
4395  Symbol data,
4396  uint32_t dim1 = 0,
4397  uint32_t dim2 = 0) {
4398  return Operator("SwapAxis")
4399  .SetParam("dim1", dim1)
4400  .SetParam("dim2", dim2)
4401  .SetInput("data", data)
4402  .CreateSymbol(symbol_name);
4403 }
4404 
4407 enum class LeakyReLUActType {
4408  kElu = 0,
4409  kLeaky = 1,
4410  kPrelu = 2,
4411  kRrelu = 3
4412 };
4413 
4440 inline Symbol LeakyReLU(const std::string& symbol_name,
4441  Symbol data,
4443  mx_float slope = 0.25,
4444  mx_float lower_bound = 0.125,
4445  mx_float upper_bound = 0.334) {
4446  static const char *LeakyReLUActTypeValues[] = {
4447  "elu",
4448  "leaky",
4449  "prelu",
4450  "rrelu"
4451  };
4452  return Operator("LeakyReLU")
4453  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
4454  .SetParam("slope", slope)
4455  .SetParam("lower_bound", lower_bound)
4456  .SetParam("upper_bound", upper_bound)
4457  .SetInput("data", data)
4458  .CreateSymbol(symbol_name);
4459 }
4460 
4516 inline Symbol BatchNorm_v1(const std::string& symbol_name,
4517  Symbol data,
4518  Symbol gamma,
4519  Symbol beta,
4520  mx_float eps = 0.001,
4521  mx_float momentum = 0.9,
4522  bool fix_gamma = true,
4523  bool use_global_stats = false,
4524  bool output_mean_var = false) {
4525  return Operator("BatchNorm_v1")
4526  .SetParam("eps", eps)
4527  .SetParam("momentum", momentum)
4528  .SetParam("fix_gamma", fix_gamma)
4529  .SetParam("use_global_stats", use_global_stats)
4530  .SetParam("output_mean_var", output_mean_var)
4531  .SetInput("data", data)
4532  .SetInput("gamma", gamma)
4533  .SetInput("beta", beta)
4534  .CreateSymbol(symbol_name);
4535 }
4536 
4578 inline Symbol Concat(const std::string& symbol_name,
4579  const std::vector<Symbol>& data,
4580  int num_args,
4581  int dim = 1) {
4582  return Operator("Concat")
4583  .SetParam("num_args", num_args)
4584  .SetParam("dim", dim)
4585 (data)
4586  .CreateSymbol(symbol_name);
4587 }
4588 
4610 inline Symbol sgd_update(const std::string& symbol_name,
4611  Symbol weight,
4612  Symbol grad,
4613  mx_float lr,
4614  mx_float wd = 0,
4615  mx_float rescale_grad = 1,
4616  mx_float clip_gradient = -1) {
4617  return Operator("sgd_update")
4618  .SetParam("lr", lr)
4619  .SetParam("wd", wd)
4620  .SetParam("rescale_grad", rescale_grad)
4621  .SetParam("clip_gradient", clip_gradient)
4622  .SetInput("weight", weight)
4623  .SetInput("grad", grad)
4624  .CreateSymbol(symbol_name);
4625 }
4626 
4662 inline Symbol sgd_mom_update(const std::string& symbol_name,
4663  Symbol weight,
4664  Symbol grad,
4665  Symbol mom,
4666  mx_float lr,
4667  mx_float momentum = 0,
4668  mx_float wd = 0,
4669  mx_float rescale_grad = 1,
4670  mx_float clip_gradient = -1) {
4671  return Operator("sgd_mom_update")
4672  .SetParam("lr", lr)
4673  .SetParam("momentum", momentum)
4674  .SetParam("wd", wd)
4675  .SetParam("rescale_grad", rescale_grad)
4676  .SetParam("clip_gradient", clip_gradient)
4677  .SetInput("weight", weight)
4678  .SetInput("grad", grad)
4679  .SetInput("mom", mom)
4680  .CreateSymbol(symbol_name);
4681 }
4682 
4697 inline Symbol mp_sgd_update(const std::string& symbol_name,
4698  Symbol weight,
4699  Symbol grad,
4700  Symbol weight32,
4701  mx_float lr,
4702  mx_float wd = 0,
4703  mx_float rescale_grad = 1,
4704  mx_float clip_gradient = -1) {
4705  return Operator("mp_sgd_update")
4706  .SetParam("lr", lr)
4707  .SetParam("wd", wd)
4708  .SetParam("rescale_grad", rescale_grad)
4709  .SetParam("clip_gradient", clip_gradient)
4710  .SetInput("weight", weight)
4711  .SetInput("grad", grad)
4712  .SetInput("weight32", weight32)
4713  .CreateSymbol(symbol_name);
4714 }
4715 
4732 inline Symbol mp_sgd_mom_update(const std::string& symbol_name,
4733  Symbol weight,
4734  Symbol grad,
4735  Symbol mom,
4736  Symbol weight32,
4737  mx_float lr,
4738  mx_float momentum = 0,
4739  mx_float wd = 0,
4740  mx_float rescale_grad = 1,
4741  mx_float clip_gradient = -1) {
4742  return Operator("mp_sgd_mom_update")
4743  .SetParam("lr", lr)
4744  .SetParam("momentum", momentum)
4745  .SetParam("wd", wd)
4746  .SetParam("rescale_grad", rescale_grad)
4747  .SetParam("clip_gradient", clip_gradient)
4748  .SetInput("weight", weight)
4749  .SetInput("grad", grad)
4750  .SetInput("mom", mom)
4751  .SetInput("weight32", weight32)
4752  .CreateSymbol(symbol_name);
4753 }
4754 
4794 inline Symbol adam_update(const std::string& symbol_name,
4795  Symbol weight,
4796  Symbol grad,
4797  Symbol mean,
4798  Symbol var,
4799  mx_float lr,
4800  mx_float beta1 = 0.9,
4801  mx_float beta2 = 0.999,
4802  mx_float epsilon = 1e-08,
4803  mx_float wd = 0,
4804  mx_float rescale_grad = 1,
4805  mx_float clip_gradient = -1) {
4806  return Operator("adam_update")
4807  .SetParam("lr", lr)
4808  .SetParam("beta1", beta1)
4809  .SetParam("beta2", beta2)
4810  .SetParam("epsilon", epsilon)
4811  .SetParam("wd", wd)
4812  .SetParam("rescale_grad", rescale_grad)
4813  .SetParam("clip_gradient", clip_gradient)
4814  .SetInput("weight", weight)
4815  .SetInput("grad", grad)
4816  .SetInput("mean", mean)
4817  .SetInput("var", var)
4818  .CreateSymbol(symbol_name);
4819 }
4820 
4874 inline Symbol rmsprop_update(const std::string& symbol_name,
4875  Symbol weight,
4876  Symbol grad,
4877  Symbol n,
4878  mx_float lr,
4879  mx_float gamma1 = 0.95,
4880  mx_float epsilon = 1e-08,
4881  mx_float wd = 0,
4882  mx_float rescale_grad = 1,
4883  mx_float clip_gradient = -1,
4884  mx_float clip_weights = -1) {
4885  return Operator("rmsprop_update")
4886  .SetParam("lr", lr)
4887  .SetParam("gamma1", gamma1)
4888  .SetParam("epsilon", epsilon)
4889  .SetParam("wd", wd)
4890  .SetParam("rescale_grad", rescale_grad)
4891  .SetParam("clip_gradient", clip_gradient)
4892  .SetParam("clip_weights", clip_weights)
4893  .SetInput("weight", weight)
4894  .SetInput("grad", grad)
4895  .SetInput("n", n)
4896  .CreateSymbol(symbol_name);
4897 }
4898 
4944 inline Symbol rmspropalex_update(const std::string& symbol_name,
4945  Symbol weight,
4946  Symbol grad,
4947  Symbol n,
4948  Symbol g,
4949  Symbol delta,
4950  mx_float lr,
4951  mx_float gamma1 = 0.95,
4952  mx_float gamma2 = 0.9,
4953  mx_float epsilon = 1e-08,
4954  mx_float wd = 0,
4955  mx_float rescale_grad = 1,
4956  mx_float clip_gradient = -1,
4957  mx_float clip_weights = -1) {
4958  return Operator("rmspropalex_update")
4959  .SetParam("lr", lr)
4960  .SetParam("gamma1", gamma1)
4961  .SetParam("gamma2", gamma2)
4962  .SetParam("epsilon", epsilon)
4963  .SetParam("wd", wd)
4964  .SetParam("rescale_grad", rescale_grad)
4965  .SetParam("clip_gradient", clip_gradient)
4966  .SetParam("clip_weights", clip_weights)
4967  .SetInput("weight", weight)
4968  .SetInput("grad", grad)
4969  .SetInput("n", n)
4970  .SetInput("g", g)
4971  .SetInput("delta", delta)
4972  .CreateSymbol(symbol_name);
4973 }
4974 
/*! \brief Value set for the "mode" parameter of Pad(); sent by string name. */
enum class PadMode : int {
  kConstant = 0,  // "constant"
  kEdge = 1,      // "edge"
  kReflect = 2    // "reflect"
};
4983 
5080 inline Symbol Pad(const std::string& symbol_name,
5081  Symbol data,
5082  PadMode mode,
5083  Shape pad_width,
5084  double constant_value = 0) {
5085  static const char *PadModeValues[] = {
5086  "constant",
5087  "edge",
5088  "reflect"
5089  };
5090  return Operator("Pad")
5091  .SetParam("mode", PadModeValues[int(mode)])
5092  .SetParam("pad_width", pad_width)
5093  .SetParam("constant_value", constant_value)
5094  .SetInput("data", data)
5095  .CreateSymbol(symbol_name);
5096 }
5097 
5107 inline Symbol IdentityAttachKLSparseReg(const std::string& symbol_name,
5108  Symbol data,
5109  mx_float sparseness_target = 0.1,
5110  mx_float penalty = 0.001,
5111  mx_float momentum = 0.9) {
5112  return Operator("IdentityAttachKLSparseReg")
5113  .SetParam("sparseness_target", sparseness_target)
5114  .SetParam("penalty", penalty)
5115  .SetParam("momentum", momentum)
5116  .SetInput("data", data)
5117  .CreateSymbol(symbol_name);
5118 }
5119 
5191 inline Symbol SliceChannel(const std::string& symbol_name,
5192  Symbol data,
5193  int num_outputs,
5194  int axis = 1,
5195  bool squeeze_axis = false) {
5196  return Operator("SliceChannel")
5197  .SetParam("num_outputs", num_outputs)
5198  .SetParam("axis", axis)
5199  .SetParam("squeeze_axis", squeeze_axis)
5200  .SetInput("data", data)
5201  .CreateSymbol(symbol_name);
5202 }
5203 
5241 inline Symbol softmax_cross_entropy(const std::string& symbol_name,
5242  Symbol data,
5243  Symbol label) {
5244  return Operator("softmax_cross_entropy")
5245  .SetInput("data", data)
5246  .SetInput("label", label)
5247  .CreateSymbol(symbol_name);
5248 }
5249 
5253  kBilinear = 0,
5254  kNearest = 1
5255 };
5256 
5261  kConcat = 0,
5262  kSum = 1
5263 };
5264 
5280 inline Symbol UpSampling(const std::string& symbol_name,
5281  const std::vector<Symbol>& data,
5282  uint32_t scale,
5283  UpSamplingSampleType sample_type,
5284  int num_args,
5285  uint32_t num_filter = 0,
5287  uint64_t workspace = 512) {
5288  static const char *UpSamplingSampleTypeValues[] = {
5289  "bilinear",
5290  "nearest"
5291  };
5292  static const char *UpSamplingMultiInputModeValues[] = {
5293  "concat",
5294  "sum"
5295  };
5296  return Operator("UpSampling")
5297  .SetParam("scale", scale)
5298  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
5299  .SetParam("num_args", num_args)
5300  .SetParam("num_filter", num_filter)
5301  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
5302  .SetParam("workspace", workspace)
5303 (data)
5304  .CreateSymbol(symbol_name);
5305 }
5306 
5370 inline Symbol BatchNorm(const std::string& symbol_name,
5371  Symbol data,
5372  Symbol gamma,
5373  Symbol beta,
5374  Symbol moving_mean,
5375  Symbol moving_var,
5376  double eps = 0.001,
5377  mx_float momentum = 0.9,
5378  bool fix_gamma = true,
5379  bool use_global_stats = false,
5380  bool output_mean_var = false,
5381  int axis = 1,
5382  bool cudnn_off = false) {
5383  return Operator("BatchNorm")
5384  .SetParam("eps", eps)
5385  .SetParam("momentum", momentum)
5386  .SetParam("fix_gamma", fix_gamma)
5387  .SetParam("use_global_stats", use_global_stats)
5388  .SetParam("output_mean_var", output_mean_var)
5389  .SetParam("axis", axis)
5390  .SetParam("cudnn_off", cudnn_off)
5391  .SetInput("data", data)
5392  .SetInput("gamma", gamma)
5393  .SetInput("beta", beta)
5394  .SetInput("moving_mean", moving_mean)
5395  .SetInput("moving_var", moving_var)
5396  .CreateSymbol(symbol_name);
5397 }
5398 
5449 inline Symbol InstanceNorm(const std::string& symbol_name,
5450  Symbol data,
5451  Symbol gamma,
5452  Symbol beta,
5453  mx_float eps = 0.001) {
5454  return Operator("InstanceNorm")
5455  .SetParam("eps", eps)
5456  .SetInput("data", data)
5457  .SetInput("gamma", gamma)
5458  .SetInput("beta", beta)
5459  .CreateSymbol(symbol_name);
5460 }
5461 
/*! \brief Value set for the "mode" parameter of RNN(); sent by string name. */
enum class RNNMode : int {
  kGru = 0,       // "gru"
  kLstm = 1,      // "lstm"
  kRnn_relu = 2,  // "rnn_relu"
  kRnn_tanh = 3   // "rnn_tanh"
};
5470 
5486 inline Symbol RNN(const std::string& symbol_name,
5487  Symbol data,
5488  Symbol parameters,
5489  Symbol state,
5490  Symbol state_cell,
5491  uint32_t state_size,
5492  uint32_t num_layers,
5493  RNNMode mode,
5494  bool bidirectional = false,
5495  mx_float p = 0,
5496  bool state_outputs = false) {
5497  static const char *RNNModeValues[] = {
5498  "gru",
5499  "lstm",
5500  "rnn_relu",
5501  "rnn_tanh"
5502  };
5503  return Operator("RNN")
5504  .SetParam("state_size", state_size)
5505  .SetParam("num_layers", num_layers)
5506  .SetParam("mode", RNNModeValues[int(mode)])
5507  .SetParam("bidirectional", bidirectional)
5508  .SetParam("p", p)
5509  .SetParam("state_outputs", state_outputs)
5510  .SetInput("data", data)
5511  .SetInput("parameters", parameters)
5512  .SetInput("state", state)
5513  .SetInput("state_cell", state_cell)
5514  .CreateSymbol(symbol_name);
5515 }
5516 
5527  kNone = 0,
5528  kFastest = 1,
5529  kLimited_workspace = 2,
5530  kOff = 3
5531 };
5532 
5537  kNone = 0,
5538  kNCDHW = 1,
5539  kNCHW = 2,
5540  kNDHWC = 3,
5541  kNHWC = 4
5542 };
5543 
5572 inline Symbol Convolution_v1(const std::string& symbol_name,
5573  Symbol data,
5574  Symbol weight,
5575  Symbol bias,
5576  Shape kernel,
5577  uint32_t num_filter,
5578  Shape stride = Shape(),
5579  Shape dilate = Shape(),
5580  Shape pad = Shape(),
5581  uint32_t num_group = 1,
5582  uint64_t workspace = 1024,
5583  bool no_bias = false,
5585  bool cudnn_off = false,
5587  static const char *Convolution_v1CudnnTuneValues[] = {
5588  "None",
5589  "fastest",
5590  "limited_workspace",
5591  "off"
5592  };
5593  static const char *Convolution_v1LayoutValues[] = {
5594  "None",
5595  "NCDHW",
5596  "NCHW",
5597  "NDHWC",
5598  "NHWC"
5599  };
5600  return Operator("Convolution_v1")
5601  .SetParam("kernel", kernel)
5602  .SetParam("num_filter", num_filter)
5603  .SetParam("stride", stride)
5604  .SetParam("dilate", dilate)
5605  .SetParam("pad", pad)
5606  .SetParam("num_group", num_group)
5607  .SetParam("workspace", workspace)
5608  .SetParam("no_bias", no_bias)
5609  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
5610  .SetParam("cudnn_off", cudnn_off)
5611  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
5612  .SetInput("data", data)
5613  .SetInput("weight", weight)
5614  .SetInput("bias", bias)
5615  .CreateSymbol(symbol_name);
5616 }
5617 
5638 inline Symbol Crop(const std::string& symbol_name,
5639  const std::vector<Symbol>& data,
5640  int num_args,
5641  Shape offset = Shape(0,0),
5642  Shape h_w = Shape(0,0),
5643  bool center_crop = false) {
5644  return Operator("Crop")
5645  .SetParam("num_args", num_args)
5646  .SetParam("offset", offset)
5647  .SetParam("h_w", h_w)
5648  .SetParam("center_crop", center_crop)
5649 (data)
5650  .CreateSymbol(symbol_name);
5651 }
5652 
5656  kAffine = 0
5657 };
5658 
5662  kBilinear = 0
5663 };
5664 
5675 inline Symbol SpatialTransformer(const std::string& symbol_name,
5676  Symbol data,
5677  Symbol loc,
5678  SpatialTransformerTransformType transform_type,
5679  SpatialTransformerSamplerType sampler_type,
5680  Shape target_shape = Shape(0,0)) {
5681  static const char *SpatialTransformerTransformTypeValues[] = {
5682  "affine"
5683  };
5684  static const char *SpatialTransformerSamplerTypeValues[] = {
5685  "bilinear"
5686  };
5687  return Operator("SpatialTransformer")
5688  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
5689  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
5690  .SetParam("target_shape", target_shape)
5691  .SetInput("data", data)
5692  .SetInput("loc", loc)
5693  .CreateSymbol(symbol_name);
5694 }
5695 
5699  kNone = 0,
5700  kFastest = 1,
5701  kLimited_workspace = 2,
5702  kOff = 3
5703 };
5704 
5708  kNone = 0,
5709  kNCDHW = 1,
5710  kNCHW = 2,
5711  kNCW = 3,
5712  kNDHWC = 4,
5713  kNHWC = 5
5714 };
5715 
5742 inline Symbol Deconvolution(const std::string& symbol_name,
5743  Symbol data,
5744  Symbol weight,
5745  Symbol bias,
5746  Shape kernel,
5747  uint32_t num_filter,
5748  Shape stride = Shape(),
5749  Shape dilate = Shape(),
5750  Shape pad = Shape(),
5751  Shape adj = Shape(),
5752  Shape target_shape = Shape(),
5753  uint32_t num_group = 1,
5754  uint64_t workspace = 512,
5755  bool no_bias = true,
5757  bool cudnn_off = false,
5759  static const char *DeconvolutionCudnnTuneValues[] = {
5760  "None",
5761  "fastest",
5762  "limited_workspace",
5763  "off"
5764  };
5765  static const char *DeconvolutionLayoutValues[] = {
5766  "None",
5767  "NCDHW",
5768  "NCHW",
5769  "NCW",
5770  "NDHWC",
5771  "NHWC"
5772  };
5773  return Operator("Deconvolution")
5774  .SetParam("kernel", kernel)
5775  .SetParam("num_filter", num_filter)
5776  .SetParam("stride", stride)
5777  .SetParam("dilate", dilate)
5778  .SetParam("pad", pad)
5779  .SetParam("adj", adj)
5780  .SetParam("target_shape", target_shape)
5781  .SetParam("num_group", num_group)
5782  .SetParam("workspace", workspace)
5783  .SetParam("no_bias", no_bias)
5784  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
5785  .SetParam("cudnn_off", cudnn_off)
5786  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
5787  .SetInput("data", data)
5788  .SetInput("weight", weight)
5789  .SetInput("bias", bias)
5790  .CreateSymbol(symbol_name);
5791 }
5792 
5796  kBatch = 0,
5797  kNull = 1,
5798  kValid = 2
5799 };
5800 
5893 inline Symbol SoftmaxOutput(const std::string& symbol_name,
5894  Symbol data,
5895  Symbol label,
5896  mx_float grad_scale = 1,
5897  mx_float ignore_label = -1,
5898  bool multi_output = false,
5899  bool use_ignore = false,
5900  bool preserve_shape = false,
5902  bool out_grad = false) {
5903  static const char *SoftmaxOutputNormalizationValues[] = {
5904  "batch",
5905  "null",
5906  "valid"
5907  };
5908  return Operator("SoftmaxOutput")
5909  .SetParam("grad_scale", grad_scale)
5910  .SetParam("ignore_label", ignore_label)
5911  .SetParam("multi_output", multi_output)
5912  .SetParam("use_ignore", use_ignore)
5913  .SetParam("preserve_shape", preserve_shape)
5914  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
5915  .SetParam("out_grad", out_grad)
5916  .SetInput("data", data)
5917  .SetInput("label", label)
5918  .CreateSymbol(symbol_name);
5919 }
5920 
5924  kBatch = 0,
5925  kNull = 1,
5926  kValid = 2
5927 };
5928 
5953 inline Symbol Softmax(const std::string& symbol_name,
5954  Symbol data,
5955  mx_float grad_scale = 1,
5956  mx_float ignore_label = -1,
5957  bool multi_output = false,
5958  bool use_ignore = false,
5959  bool preserve_shape = false,
5961  bool out_grad = false) {
5962  static const char *SoftmaxNormalizationValues[] = {
5963  "batch",
5964  "null",
5965  "valid"
5966  };
5967  return Operator("Softmax")
5968  .SetParam("grad_scale", grad_scale)
5969  .SetParam("ignore_label", ignore_label)
5970  .SetParam("multi_output", multi_output)
5971  .SetParam("use_ignore", use_ignore)
5972  .SetParam("preserve_shape", preserve_shape)
5973  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
5974  .SetParam("out_grad", out_grad)
5975  .SetInput("data", data)
5976  .CreateSymbol(symbol_name);
5977 }
5978 
6054 inline Symbol SequenceReverse(const std::string& symbol_name,
6055  Symbol data,
6056  Symbol sequence_length,
6057  bool use_sequence_length = false) {
6058  return Operator("SequenceReverse")
6059  .SetParam("use_sequence_length", use_sequence_length)
6060  .SetInput("data", data)
6061  .SetInput("sequence_length", sequence_length)
6062  .CreateSymbol(symbol_name);
6063 }
6064 
6119 inline Symbol SequenceLast(const std::string& symbol_name,
6120  Symbol data,
6121  Symbol sequence_length,
6122  bool use_sequence_length = false) {
6123  return Operator("SequenceLast")
6124  .SetParam("use_sequence_length", use_sequence_length)
6125  .SetInput("data", data)
6126  .SetInput("sequence_length", sequence_length)
6127  .CreateSymbol(symbol_name);
6128 }
6129 
6178 inline Symbol Correlation(const std::string& symbol_name,
6179  Symbol data1,
6180  Symbol data2,
6181  uint32_t kernel_size = 1,
6182  uint32_t max_displacement = 1,
6183  uint32_t stride1 = 1,
6184  uint32_t stride2 = 1,
6185  uint32_t pad_size = 0,
6186  bool is_multiply = true) {
6187  return Operator("Correlation")
6188  .SetParam("kernel_size", kernel_size)
6189  .SetParam("max_displacement", max_displacement)
6190  .SetParam("stride1", stride1)
6191  .SetParam("stride2", stride2)
6192  .SetParam("pad_size", pad_size)
6193  .SetParam("is_multiply", is_multiply)
6194  .SetInput("data1", data1)
6195  .SetInput("data2", data2)
6196  .CreateSymbol(symbol_name);
6197 }
6198 
6214 inline Symbol SVMOutput(const std::string& symbol_name,
6215  Symbol data,
6216  Symbol label,
6217  mx_float margin = 1,
6218  mx_float regularization_coefficient = 1,
6219  bool use_linear = false) {
6220  return Operator("SVMOutput")
6221  .SetParam("margin", margin)
6222  .SetParam("regularization_coefficient", regularization_coefficient)
6223  .SetParam("use_linear", use_linear)
6224  .SetInput("data", data)
6225  .SetInput("label", label)
6226  .CreateSymbol(symbol_name);
6227 }
6228 
6232  kChannel = 0,
6233  kInstance = 1,
6234  kSpatial = 2
6235 };
6236 
6299 inline Symbol L2Normalization(const std::string& symbol_name,
6300  Symbol data,
6301  mx_float eps = 1e-10,
6303  static const char *L2NormalizationModeValues[] = {
6304  "channel",
6305  "instance",
6306  "spatial"
6307  };
6308  return Operator("L2Normalization")
6309  .SetParam("eps", eps)
6310  .SetParam("mode", L2NormalizationModeValues[int(mode)])
6311  .SetInput("data", data)
6312  .CreateSymbol(symbol_name);
6313 }
6314 
6342 inline Symbol LRN(const std::string& symbol_name,
6343  Symbol data,
6344  uint32_t nsize,
6345  mx_float alpha = 0.0001,
6346  mx_float beta = 0.75,
6347  mx_float knorm = 2) {
6348  return Operator("LRN")
6349  .SetParam("nsize", nsize)
6350  .SetParam("alpha", alpha)
6351  .SetParam("beta", beta)
6352  .SetParam("knorm", knorm)
6353  .SetInput("data", data)
6354  .CreateSymbol(symbol_name);
6355 }
6356 
6382 inline Symbol FullyConnected(const std::string& symbol_name,
6383  Symbol data,
6384  Symbol weight,
6385  Symbol bias,
6386  int num_hidden,
6387  bool no_bias = false) {
6388  return Operator("FullyConnected")
6389  .SetParam("num_hidden", num_hidden)
6390  .SetParam("no_bias", no_bias)
6391  .SetInput("data", data)
6392  .SetInput("weight", weight)
6393  .SetInput("bias", bias)
6394  .CreateSymbol(symbol_name);
6395 }
6396 
6474 inline Symbol SequenceMask(const std::string& symbol_name,
6475  Symbol data,
6476  Symbol sequence_length,
6477  bool use_sequence_length = false,
6478  mx_float value = 0) {
6479  return Operator("SequenceMask")
6480  .SetParam("use_sequence_length", use_sequence_length)
6481  .SetParam("value", value)
6482  .SetInput("data", data)
6483  .SetInput("sequence_length", sequence_length)
6484  .CreateSymbol(symbol_name);
6485 }
6486 
6491  kAffine = 0,
6492  kWarp = 1
6493 };
6494 
6505 inline Symbol GridGenerator(const std::string& symbol_name,
6506  Symbol data,
6507  GridGeneratorTransformType transform_type,
6508  Shape target_shape = Shape(0,0)) {
6509  static const char *GridGeneratorTransformTypeValues[] = {
6510  "affine",
6511  "warp"
6512  };
6513  return Operator("GridGenerator")
6514  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
6515  .SetParam("target_shape", target_shape)
6516  .SetInput("data", data)
6517  .CreateSymbol(symbol_name);
6518 }
6519 
6523  kAvg = 0,
6524  kMax = 1,
6525  kSum = 2
6526 };
6527 
6531  kFull = 0,
6532  kValid = 1
6533 };
6534 
6586 inline Symbol Pooling_v1(const std::string& symbol_name,
6587  Symbol data,
6588  Shape kernel,
6589  Pooling_v1PoolType pool_type,
6590  bool global_pool = false,
6592  Shape stride = Shape(),
6593  Shape pad = Shape()) {
6594  static const char *Pooling_v1PoolTypeValues[] = {
6595  "avg",
6596  "max",
6597  "sum"
6598  };
6599  static const char *Pooling_v1PoolingConventionValues[] = {
6600  "full",
6601  "valid"
6602  };
6603  return Operator("Pooling_v1")
6604  .SetParam("kernel", kernel)
6605  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
6606  .SetParam("global_pool", global_pool)
6607  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
6608  .SetParam("stride", stride)
6609  .SetParam("pad", pad)
6610  .SetInput("data", data)
6611  .CreateSymbol(symbol_name);
6612 }
6613 
6617  kNone = 0,
6618  kFastest = 1,
6619  kLimited_workspace = 2,
6620  kOff = 3
6621 };
6622 
6626 enum class ConvolutionLayout {
6627  kNone = 0,
6628  kNCDHW = 1,
6629  kNCHW = 2,
6630  kNCW = 3,
6631  kNDHWC = 4,
6632  kNHWC = 5
6633 };
6634 
6729 inline Symbol Convolution(const std::string& symbol_name,
6730  Symbol data,
6731  Symbol weight,
6732  Symbol bias,
6733  Shape kernel,
6734  uint32_t num_filter,
6735  Shape stride = Shape(),
6736  Shape dilate = Shape(),
6737  Shape pad = Shape(),
6738  uint32_t num_group = 1,
6739  uint64_t workspace = 1024,
6740  bool no_bias = false,
6742  bool cudnn_off = false,
6744  static const char *ConvolutionCudnnTuneValues[] = {
6745  "None",
6746  "fastest",
6747  "limited_workspace",
6748  "off"
6749  };
6750  static const char *ConvolutionLayoutValues[] = {
6751  "None",
6752  "NCDHW",
6753  "NCHW",
6754  "NCW",
6755  "NDHWC",
6756  "NHWC"
6757  };
6758  return Operator("Convolution")
6759  .SetParam("kernel", kernel)
6760  .SetParam("num_filter", num_filter)
6761  .SetParam("stride", stride)
6762  .SetParam("dilate", dilate)
6763  .SetParam("pad", pad)
6764  .SetParam("num_group", num_group)
6765  .SetParam("workspace", workspace)
6766  .SetParam("no_bias", no_bias)
6767  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
6768  .SetParam("cudnn_off", cudnn_off)
6769  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
6770  .SetInput("data", data)
6771  .SetInput("weight", weight)
6772  .SetInput("bias", bias)
6773  .CreateSymbol(symbol_name);
6774 }
6775 
6856 inline Symbol BilinearSampler(const std::string& symbol_name,
6857  Symbol data,
6858  Symbol grid) {
6859  return Operator("BilinearSampler")
6860  .SetInput("data", data)
6861  .SetInput("grid", grid)
6862  .CreateSymbol(symbol_name);
6863 }
6864 
6867 enum class PoolingPoolType {
6868  kAvg = 0,
6869  kMax = 1,
6870  kSum = 2
6871 };
6872 
6876  kFull = 0,
6877  kValid = 1
6878 };
6879 
6933 inline Symbol Pooling(const std::string& symbol_name,
6934  Symbol data,
6935  Shape kernel,
6936  PoolingPoolType pool_type,
6937  bool global_pool = false,
6938  bool cudnn_off = false,
6940  Shape stride = Shape(),
6941  Shape pad = Shape()) {
6942  static const char *PoolingPoolTypeValues[] = {
6943  "avg",
6944  "max",
6945  "sum"
6946  };
6947  static const char *PoolingPoolingConventionValues[] = {
6948  "full",
6949  "valid"
6950  };
6951  return Operator("Pooling")
6952  .SetParam("kernel", kernel)
6953  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
6954  .SetParam("global_pool", global_pool)
6955  .SetParam("cudnn_off", cudnn_off)
6956  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
6957  .SetParam("stride", stride)
6958  .SetParam("pad", pad)
6959  .SetInput("data", data)
6960  .CreateSymbol(symbol_name);
6961 }
6962 
6965 enum class DropoutMode {
6966  kAlways = 0,
6967  kTraining = 1
6968 };
6969 
7009 inline Symbol Dropout(const std::string& symbol_name,
7010  Symbol data,
7011  mx_float p = 0.5,
7013  static const char *DropoutModeValues[] = {
7014  "always",
7015  "training"
7016  };
7017  return Operator("Dropout")
7018  .SetParam("p", p)
7019  .SetParam("mode", DropoutModeValues[int(mode)])
7020  .SetInput("data", data)
7021  .CreateSymbol(symbol_name);
7022 }
7023 
7026 enum class ActivationActType {
7027  kRelu = 0,
7028  kSigmoid = 1,
7029  kSoftrelu = 2,
7030  kTanh = 3
7031 };
7032 
7051 inline Symbol Activation(const std::string& symbol_name,
7052  Symbol data,
7053  ActivationActType act_type) {
7054  static const char *ActivationActTypeValues[] = {
7055  "relu",
7056  "sigmoid",
7057  "softrelu",
7058  "tanh"
7059  };
7060  return Operator("Activation")
7061  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
7062  .SetInput("data", data)
7063  .CreateSymbol(symbol_name);
7064 }
7065 
7122 inline Symbol ROIPooling(const std::string& symbol_name,
7123  Symbol data,
7124  Symbol rois,
7125  Shape pooled_size,
7126  mx_float spatial_scale) {
7127  return Operator("ROIPooling")
7128  .SetParam("pooled_size", pooled_size)
7129  .SetParam("spatial_scale", spatial_scale)
7130  .SetInput("data", data)
7131  .SetInput("rois", rois)
7132  .CreateSymbol(symbol_name);
7133 }
7134 
7159 inline Symbol LinearRegressionOutput(const std::string& symbol_name,
7160  Symbol data,
7161  Symbol label,
7162  mx_float grad_scale = 1) {
7163  return Operator("LinearRegressionOutput")
7164  .SetParam("grad_scale", grad_scale)
7165  .SetInput("data", data)
7166  .SetInput("label", label)
7167  .CreateSymbol(symbol_name);
7168 }
7169 
7195 inline Symbol MAERegressionOutput(const std::string& symbol_name,
7196  Symbol data,
7197  Symbol label,
7198  mx_float grad_scale = 1) {
7199  return Operator("MAERegressionOutput")
7200  .SetParam("grad_scale", grad_scale)
7201  .SetInput("data", data)
7202  .SetInput("label", label)
7203  .CreateSymbol(symbol_name);
7204 }
7205 
7231 inline Symbol LogisticRegressionOutput(const std::string& symbol_name,
7232  Symbol data,
7233  Symbol label,
7234  mx_float grad_scale = 1) {
7235  return Operator("LogisticRegressionOutput")
7236  .SetParam("grad_scale", grad_scale)
7237  .SetInput("data", data)
7238  .SetInput("label", label)
7239  .CreateSymbol(symbol_name);
7240 }
7241 
7246  kChannel = 0,
7247  kInstance = 1
7248 };
7249 
7283 inline Symbol SoftmaxActivation(const std::string& symbol_name,
7284  Symbol data,
7286  static const char *SoftmaxActivationModeValues[] = {
7287  "channel",
7288  "instance"
7289  };
7290  return Operator("SoftmaxActivation")
7291  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
7292  .SetInput("data", data)
7293  .CreateSymbol(symbol_name);
7294 }
7295 
7301  kBatch = 0,
7302  kNull = 1,
7303  kValid = 2
7304 };
7305 
7340 inline Symbol MakeLoss(const std::string& symbol_name,
7341  Symbol data,
7342  mx_float grad_scale = 1,
7343  mx_float valid_thresh = 0,
7345  static const char *MakeLossNormalizationValues[] = {
7346  "batch",
7347  "null",
7348  "valid"
7349  };
7350  return Operator("MakeLoss")
7351  .SetParam("grad_scale", grad_scale)
7352  .SetParam("valid_thresh", valid_thresh)
7353  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
7354  .SetInput("data", data)
7355  .CreateSymbol(symbol_name);
7356 }
7357 
7366 inline Symbol choose_element_0index(const std::string& symbol_name,
7367  Symbol lhs,
7368  Symbol rhs) {
7369  return Operator("choose_element_0index")
7370  .SetInput("lhs", lhs)
7371  .SetInput("rhs", rhs)
7372  .CreateSymbol(symbol_name);
7373 }
7374 
7384 inline Symbol fill_element_0index(const std::string& symbol_name,
7385  Symbol lhs,
7386  Symbol mhs,
7387  Symbol rhs) {
7388  return Operator("fill_element_0index")
7389  .SetInput("lhs", lhs)
7390  .SetInput("mhs", mhs)
7391  .SetInput("rhs", rhs)
7392  .CreateSymbol(symbol_name);
7393 }
7394 
7423 inline Symbol softmax(Symbol data,
7424  int axis = -1) {
7425  return Operator("softmax")
7426  .SetParam("axis", axis)
7427  .SetInput("data", data)
7428  .CreateSymbol();
7429 }
7430 
7453  int axis = -1) {
7454  return Operator("log_softmax")
7455  .SetParam("axis", axis)
7456  .SetInput("data", data)
7457  .CreateSymbol();
7458 }
7459 
7498  Shape shape = Shape(),
7499  bool get_prob = false,
7501  static const char *Sample_multinomialDtypeValues[] = {
7502  "int32"
7503  };
7504  return Operator("sample_multinomial")
7505  .SetParam("shape", shape)
7506  .SetParam("get_prob", get_prob)
7507  .SetParam("dtype", Sample_multinomialDtypeValues[int(dtype)])
7508  .SetInput("data", data)
7509  .CreateSymbol();
7510 }
7511 
7548  Symbol high,
7549  Shape shape = Shape(),
7551  static const char *Sample_uniformDtypeValues[] = {
7552  "None",
7553  "float16",
7554  "float32",
7555  "float64"
7556  };
7557  return Operator("sample_uniform")
7558  .SetParam("shape", shape)
7559  .SetParam("dtype", Sample_uniformDtypeValues[int(dtype)])
7560  .SetInput("low", low)
7561  .SetInput("high", high)
7562  .CreateSymbol();
7563 }
7564 
7601  Symbol sigma,
7602  Shape shape = Shape(),
7604  static const char *Sample_normalDtypeValues[] = {
7605  "None",
7606  "float16",
7607  "float32",
7608  "float64"
7609  };
7610  return Operator("sample_normal")
7611  .SetParam("shape", shape)
7612  .SetParam("dtype", Sample_normalDtypeValues[int(dtype)])
7613  .SetInput("mu", mu)
7614  .SetInput("sigma", sigma)
7615  .CreateSymbol();
7616 }
7617 
7654  Symbol beta,
7655  Shape shape = Shape(),
7657  static const char *Sample_gammaDtypeValues[] = {
7658  "None",
7659  "float16",
7660  "float32",
7661  "float64"
7662  };
7663  return Operator("sample_gamma")
7664  .SetParam("shape", shape)
7665  .SetParam("dtype", Sample_gammaDtypeValues[int(dtype)])
7666  .SetInput("alpha", alpha)
7667  .SetInput("beta", beta)
7668  .CreateSymbol();
7669 }
7670 
7705  Shape shape = Shape(),
7707  static const char *Sample_exponentialDtypeValues[] = {
7708  "None",
7709  "float16",
7710  "float32",
7711  "float64"
7712  };
7713  return Operator("sample_exponential")
7714  .SetParam("shape", shape)
7715  .SetParam("dtype", Sample_exponentialDtypeValues[int(dtype)])
7716  .SetInput("lam", lam)
7717  .CreateSymbol();
7718 }
7719 
7756  Shape shape = Shape(),
7758  static const char *Sample_poissonDtypeValues[] = {
7759  "None",
7760  "float16",
7761  "float32",
7762  "float64"
7763  };
7764  return Operator("sample_poisson")
7765  .SetParam("shape", shape)
7766  .SetParam("dtype", Sample_poissonDtypeValues[int(dtype)])
7767  .SetInput("lam", lam)
7768  .CreateSymbol();
7769 }
7770 
7809  Symbol p,
7810  Shape shape = Shape(),
7812  static const char *Sample_negative_binomialDtypeValues[] = {
7813  "None",
7814  "float16",
7815  "float32",
7816  "float64"
7817  };
7818  return Operator("sample_negative_binomial")
7819  .SetParam("shape", shape)
7820  .SetParam("dtype", Sample_negative_binomialDtypeValues[int(dtype)])
7821  .SetInput("k", k)
7822  .SetInput("p", p)
7823  .CreateSymbol();
7824 }
7825 
7864  Symbol alpha,
7865  Shape shape = Shape(),
7867  static const char *Sample_generalized_negative_binomialDtypeValues[] = {
7868  "None",
7869  "float16",
7870  "float32",
7871  "float64"
7872  };
7873  return Operator("sample_generalized_negative_binomial")
7874  .SetParam("shape", shape)
7875  .SetParam("dtype", Sample_generalized_negative_binomialDtypeValues[int(dtype)])
7876  .SetInput("mu", mu)
7877  .SetInput("alpha", alpha)
7878  .CreateSymbol();
7879 }
7880 
7905  mx_float high = 1,
7906  Shape shape = Shape(),
7907  const std::string& ctx = "",
7909  static const char *Random_uniformDtypeValues[] = {
7910  "None",
7911  "float16",
7912  "float32",
7913  "float64"
7914  };
7915  return Operator("random_uniform")
7916  .SetParam("low", low)
7917  .SetParam("high", high)
7918  .SetParam("shape", shape)
7919  .SetParam("dtype", Random_uniformDtypeValues[int(dtype)])
7920  .CreateSymbol();
7921 }
7922 
7945  mx_float scale = 1,
7946  Shape shape = Shape(),
7947  const std::string& ctx = "",
7949  static const char *Random_normalDtypeValues[] = {
7950  "None",
7951  "float16",
7952  "float32",
7953  "float64"
7954  };
7955  return Operator("random_normal")
7956  .SetParam("loc", loc)
7957  .SetParam("scale", scale)
7958  .SetParam("shape", shape)
7959  .SetParam("dtype", Random_normalDtypeValues[int(dtype)])
7960  .CreateSymbol();
7961 }
7962 
7982 inline Symbol random_gamma(mx_float alpha = 1,
7983  mx_float beta = 1,
7984  Shape shape = Shape(),
7985  const std::string& ctx = "",
7987  static const char *Random_gammaDtypeValues[] = {
7988  "None",
7989  "float16",
7990  "float32",
7991  "float64"
7992  };
7993  return Operator("random_gamma")
7994  .SetParam("alpha", alpha)
7995  .SetParam("beta", beta)
7996  .SetParam("shape", shape)
7997  .SetParam("dtype", Random_gammaDtypeValues[int(dtype)])
7998  .CreateSymbol();
7999 }
8000 
8020  Shape shape = Shape(),
8021  const std::string& ctx = "",
8023  static const char *Random_exponentialDtypeValues[] = {
8024  "None",
8025  "float16",
8026  "float32",
8027  "float64"
8028  };
8029  return Operator("random_exponential")
8030  .SetParam("lam", lam)
8031  .SetParam("shape", shape)
8032  .SetParam("dtype", Random_exponentialDtypeValues[int(dtype)])
8033  .CreateSymbol();
8034 }
8035 
8056  Shape shape = Shape(),
8057  const std::string& ctx = "",
8059  static const char *Random_poissonDtypeValues[] = {
8060  "None",
8061  "float16",
8062  "float32",
8063  "float64"
8064  };
8065  return Operator("random_poisson")
8066  .SetParam("lam", lam)
8067  .SetParam("shape", shape)
8068  .SetParam("dtype", Random_poissonDtypeValues[int(dtype)])
8069  .CreateSymbol();
8070 }
8071 
8094  mx_float p = 1,
8095  Shape shape = Shape(),
8096  const std::string& ctx = "",
8098  static const char *Random_negative_binomialDtypeValues[] = {
8099  "None",
8100  "float16",
8101  "float32",
8102  "float64"
8103  };
8104  return Operator("random_negative_binomial")
8105  .SetParam("k", k)
8106  .SetParam("p", p)
8107  .SetParam("shape", shape)
8108  .SetParam("dtype", Random_negative_binomialDtypeValues[int(dtype)])
8109  .CreateSymbol();
8110 }
8111 
8135  mx_float alpha = 1,
8136  Shape shape = Shape(),
8137  const std::string& ctx = "",
8139  static const char *Random_generalized_negative_binomialDtypeValues[] = {
8140  "None",
8141  "float16",
8142  "float32",
8143  "float64"
8144  };
8145  return Operator("random_generalized_negative_binomial")
8146  .SetParam("mu", mu)
8147  .SetParam("alpha", alpha)
8148  .SetParam("shape", shape)
8149  .SetParam("dtype", Random_generalized_negative_binomialDtypeValues[int(dtype)])
8150  .CreateSymbol();
8151 }
8152 
8175  Symbol rhs) {
8176  return Operator("broadcast_power")
8177  .SetInput("lhs", lhs)
8178  .SetInput("rhs", rhs)
8179  .CreateSymbol();
8180 }
8181 
8206  Symbol rhs) {
8207  return Operator("broadcast_maximum")
8208  .SetInput("lhs", lhs)
8209  .SetInput("rhs", rhs)
8210  .CreateSymbol();
8211 }
8212 
8237  Symbol rhs) {
8238  return Operator("broadcast_minimum")
8239  .SetInput("lhs", lhs)
8240  .SetInput("rhs", rhs)
8241  .CreateSymbol();
8242 }
8243 
8274  Symbol rhs) {
8275  return Operator("broadcast_hypot")
8276  .SetInput("lhs", lhs)
8277  .SetInput("rhs", rhs)
8278  .CreateSymbol();
8279 }
8280 
8354 inline Symbol Reshape(Symbol data,
8355  Shape shape = Shape(),
8356  bool reverse = false,
8357  Shape target_shape = Shape(),
8358  bool keep_highest = false) {
8359  return Operator("Reshape")
8360  .SetParam("shape", shape)
8361  .SetParam("reverse", reverse)
8362  .SetParam("target_shape", target_shape)
8363  .SetParam("keep_highest", keep_highest)
8364  .SetInput("data", data)
8365  .CreateSymbol();
8366 }
8367 
8397 inline Symbol Flatten(Symbol data) {
8398  return Operator("Flatten")
8399  .SetInput("data", data)
8400  .CreateSymbol();
8401 }
8402 
8439  Shape axes = Shape()) {
8440  return Operator("transpose")
8441  .SetParam("axes", axes)
8442  .SetInput("data", data)
8443  .CreateSymbol();
8444 }
8445 
8461  int axis) {
8462  return Operator("expand_dims")
8463  .SetParam("axis", axis)
8464  .SetInput("data", data)
8465  .CreateSymbol();
8466 }
8467 
8500 inline Symbol slice(Symbol data,
8501  Shape begin,
8502  Shape end) {
8503  return Operator("slice")
8504  .SetParam("begin", begin)
8505  .SetParam("end", end)
8506  .SetInput("data", data)
8507  .CreateSymbol();
8508 }
8509 
8542  int axis,
8543  int begin,
8544  dmlc::optional<int> end) {
8545  return Operator("slice_axis")
8546  .SetParam("axis", axis)
8547  .SetParam("begin", begin)
8548  .SetParam("end", end)
8549  .SetInput("data", data)
8550  .CreateSymbol();
8551 }
8552 
8583 inline Symbol dot(Symbol lhs,
8584  Symbol rhs,
8585  bool transpose_a = false,
8586  bool transpose_b = false) {
8587  return Operator("dot")
8588  .SetParam("transpose_a", transpose_a)
8589  .SetParam("transpose_b", transpose_b)
8590  .SetInput("lhs", lhs)
8591  .SetInput("rhs", rhs)
8592  .CreateSymbol();
8593 }
8594 
8617  Symbol rhs,
8618  bool transpose_a = false,
8619  bool transpose_b = false) {
8620  return Operator("batch_dot")
8621  .SetParam("transpose_a", transpose_a)
8622  .SetParam("transpose_b", transpose_b)
8623  .SetInput("lhs", lhs)
8624  .SetInput("rhs", rhs)
8625  .CreateSymbol();
8626 }
8627 
8650 inline Symbol clip(Symbol data,
8651  mx_float a_min,
8652  mx_float a_max) {
8653  return Operator("clip")
8654  .SetParam("a_min", a_min)
8655  .SetParam("a_max", a_max)
8656  .SetInput("data", data)
8657  .CreateSymbol();
8658 }
8659 
8693 inline Symbol repeat(Symbol data,
8694  int repeats,
8695  dmlc::optional<int> axis = dmlc::optional<int>()) {
8696  return Operator("repeat")
8697  .SetParam("repeats", repeats)
8698  .SetParam("axis", axis)
8699  .SetInput("data", data)
8700  .CreateSymbol();
8701 }
8702 
8747 inline Symbol tile(Symbol data,
8748  Shape reps) {
8749  return Operator("tile")
8750  .SetParam("reps", reps)
8751  .SetInput("data", data)
8752  .CreateSymbol();
8753 }
8754 
8777 inline Symbol reverse(Symbol data,
8778  Shape axis) {
8779  return Operator("reverse")
8780  .SetParam("axis", axis)
8781  .SetInput("data", data)
8782  .CreateSymbol();
8783 }
8784 
8807 inline Symbol stack(const std::vector<Symbol>& data,
8808  int num_args,
8809  int axis = 0) {
8810  return Operator("stack")
8811  .SetParam("num_args", num_args)
8812  .SetParam("axis", axis)
8813 (data)
8814  .CreateSymbol();
8815 }
8816 
8833 inline Symbol zeros_like(Symbol data) {
8834  return Operator("zeros_like")
8835  .SetInput("data", data)
8836  .CreateSymbol();
8837 }
8838 
8855 inline Symbol ones_like(Symbol data) {
8856  return Operator("ones_like")
8857  .SetInput("data", data)
8858  .CreateSymbol();
8859 }
8860 
8888  Symbol rhs) {
8889  return Operator("broadcast_add")
8890  .SetInput("lhs", lhs)
8891  .SetInput("rhs", rhs)
8892  .CreateSymbol();
8893 }
8894 
8922  Symbol rhs) {
8923  return Operator("broadcast_sub")
8924  .SetInput("lhs", lhs)
8925  .SetInput("rhs", rhs)
8926  .CreateSymbol();
8927 }
8928 
8951  Symbol rhs) {
8952  return Operator("broadcast_mul")
8953  .SetInput("lhs", lhs)
8954  .SetInput("rhs", rhs)
8955  .CreateSymbol();
8956 }
8957 
8980  Symbol rhs) {
8981  return Operator("broadcast_div")
8982  .SetInput("lhs", lhs)
8983  .SetInput("rhs", rhs)
8984  .CreateSymbol();
8985 }
8986 
9009  Symbol rhs) {
9010  return Operator("broadcast_mod")
9011  .SetInput("lhs", lhs)
9012  .SetInput("rhs", rhs)
9013  .CreateSymbol();
9014 }
9015 
9029 inline Symbol add_n(const std::vector<Symbol>& args) {
9030  return Operator("add_n")
9031 (args)
9032  .CreateSymbol();
9033 }
9034 
9065 inline Symbol argmax(Symbol data,
9066  dmlc::optional<int> axis = dmlc::optional<int>(),
9067  bool keepdims = false) {
9068  return Operator("argmax")
9069  .SetParam("axis", axis)
9070  .SetParam("keepdims", keepdims)
9071  .SetInput("data", data)
9072  .CreateSymbol();
9073 }
9074 
9105 inline Symbol argmin(Symbol data,
9106  dmlc::optional<int> axis = dmlc::optional<int>(),
9107  bool keepdims = false) {
9108  return Operator("argmin")
9109  .SetParam("axis", axis)
9110  .SetParam("keepdims", keepdims)
9111  .SetInput("data", data)
9112  .CreateSymbol();
9113 }
9114 
9137  return Operator("argmax_channel")
9138  .SetInput("data", data)
9139  .CreateSymbol();
9140 }
9141 
9186 inline Symbol pick(Symbol data,
9187  Symbol index,
9188  dmlc::optional<int> axis = dmlc::optional<int>(),
9189  bool keepdims = false) {
9190  return Operator("pick")
9191  .SetParam("axis", axis)
9192  .SetParam("keepdims", keepdims)
9193  .SetInput("data", data)
9194  .SetInput("index", index)
9195  .CreateSymbol();
9196 }
9197 
9241 inline Symbol sum(Symbol data,
9242  Shape axis = Shape(),
9243  bool keepdims = false,
9244  bool exclude = false) {
9245  return Operator("sum")
9246  .SetParam("axis", axis)
9247  .SetParam("keepdims", keepdims)
9248  .SetParam("exclude", exclude)
9249  .SetInput("data", data)
9250  .CreateSymbol();
9251 }
9252 
9276 inline Symbol mean(Symbol data,
9277  Shape axis = Shape(),
9278  bool keepdims = false,
9279  bool exclude = false) {
9280  return Operator("mean")
9281  .SetParam("axis", axis)
9282  .SetParam("keepdims", keepdims)
9283  .SetParam("exclude", exclude)
9284  .SetInput("data", data)
9285  .CreateSymbol();
9286 }
9287 
9311 inline Symbol prod(Symbol data,
9312  Shape axis = Shape(),
9313  bool keepdims = false,
9314  bool exclude = false) {
9315  return Operator("prod")
9316  .SetParam("axis", axis)
9317  .SetParam("keepdims", keepdims)
9318  .SetParam("exclude", exclude)
9319  .SetInput("data", data)
9320  .CreateSymbol();
9321 }
9322 
9348 inline Symbol nansum(Symbol data,
9349  Shape axis = Shape(),
9350  bool keepdims = false,
9351  bool exclude = false) {
9352  return Operator("nansum")
9353  .SetParam("axis", axis)
9354  .SetParam("keepdims", keepdims)
9355  .SetParam("exclude", exclude)
9356  .SetInput("data", data)
9357  .CreateSymbol();
9358 }
9359 
9385 inline Symbol nanprod(Symbol data,
9386  Shape axis = Shape(),
9387  bool keepdims = false,
9388  bool exclude = false) {
9389  return Operator("nanprod")
9390  .SetParam("axis", axis)
9391  .SetParam("keepdims", keepdims)
9392  .SetParam("exclude", exclude)
9393  .SetInput("data", data)
9394  .CreateSymbol();
9395 }
9396 
9420 inline Symbol max(Symbol data,
9421  Shape axis = Shape(),
9422  bool keepdims = false,
9423  bool exclude = false) {
9424  return Operator("max")
9425  .SetParam("axis", axis)
9426  .SetParam("keepdims", keepdims)
9427  .SetParam("exclude", exclude)
9428  .SetInput("data", data)
9429  .CreateSymbol();
9430 }
9431 
9455 inline Symbol min(Symbol data,
9456  Shape axis = Shape(),
9457  bool keepdims = false,
9458  bool exclude = false) {
9459  return Operator("min")
9460  .SetParam("axis", axis)
9461  .SetParam("keepdims", keepdims)
9462  .SetParam("exclude", exclude)
9463  .SetInput("data", data)
9464  .CreateSymbol();
9465 }
9466 
9496  Shape axis = Shape(),
9497  Shape size = Shape()) {
9498  return Operator("broadcast_axis")
9499  .SetParam("axis", axis)
9500  .SetParam("size", size)
9501  .SetInput("data", data)
9502  .CreateSymbol();
9503 }
9504 
9533  Shape shape = Shape()) {
9534  return Operator("broadcast_to")
9535  .SetParam("shape", shape)
9536  .SetInput("data", data)
9537  .CreateSymbol();
9538 }
9539 
9556 inline Symbol norm(Symbol data) {
9557  return Operator("norm")
9558  .SetInput("data", data)
9559  .CreateSymbol();
9560 }
9561 
9602 inline Symbol topk(Symbol data,
9603  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9604  int k = 1,
9605  TopkRetTyp ret_typ = TopkRetTyp::kIndices,
9606  bool is_ascend = false) {
9607  static const char *TopkRetTypValues[] = {
9608  "both",
9609  "indices",
9610  "mask",
9611  "value"
9612  };
9613  return Operator("topk")
9614  .SetParam("axis", axis)
9615  .SetParam("k", k)
9616  .SetParam("ret_typ", TopkRetTypValues[int(ret_typ)])
9617  .SetParam("is_ascend", is_ascend)
9618  .SetInput("data", data)
9619  .CreateSymbol();
9620 }
9621 
9653 inline Symbol sort(Symbol data,
9654  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9655  bool is_ascend = true) {
9656  return Operator("sort")
9657  .SetParam("axis", axis)
9658  .SetParam("is_ascend", is_ascend)
9659  .SetInput("data", data)
9660  .CreateSymbol();
9661 }
9662 
9692 inline Symbol argsort(Symbol data,
9693  dmlc::optional<int> axis = dmlc::optional<int>(-1),
9694  bool is_ascend = true) {
9695  return Operator("argsort")
9696  .SetParam("axis", axis)
9697  .SetParam("is_ascend", is_ascend)
9698  .SetInput("data", data)
9699  .CreateSymbol();
9700 }
9701 
9709  Symbol rhs) {
9710  return Operator("elemwise_add")
9711  .SetInput("lhs", lhs)
9712  .SetInput("rhs", rhs)
9713  .CreateSymbol();
9714 }
9715 
9728 inline Symbol relu(Symbol data) {
9729  return Operator("relu")
9730  .SetInput("data", data)
9731  .CreateSymbol();
9732 }
9733 
9746 inline Symbol sigmoid(Symbol data) {
9747  return Operator("sigmoid")
9748  .SetInput("data", data)
9749  .CreateSymbol();
9750 }
9751 
9784 inline Symbol BlockGrad(Symbol data) {
9785  return Operator("BlockGrad")
9786  .SetInput("data", data)
9787  .CreateSymbol();
9788 }
9789 
9799 inline Symbol make_loss(Symbol data) {
9800  return Operator("make_loss")
9801  .SetInput("data", data)
9802  .CreateSymbol();
9803 }
9804 
9823 inline Symbol Cast(Symbol data,
9824  CastDtype dtype) {
9825  static const char *CastDtypeValues[] = {
9826  "float16",
9827  "float32",
9828  "float64",
9829  "int32",
9830  "uint8"
9831  };
9832  return Operator("Cast")
9833  .SetParam("dtype", CastDtypeValues[int(dtype)])
9834  .SetInput("data", data)
9835  .CreateSymbol();
9836 }
9837 
9845 inline Symbol negative(Symbol data) {
9846  return Operator("negative")
9847  .SetInput("data", data)
9848  .CreateSymbol();
9849 }
9850 
9866 inline Symbol reciprocal(Symbol data) {
9867  return Operator("reciprocal")
9868  .SetInput("data", data)
9869  .CreateSymbol();
9870 }
9871 
9885 inline Symbol abs(Symbol data) {
9886  return Operator("abs")
9887  .SetInput("data", data)
9888  .CreateSymbol();
9889 }
9890 
9904 inline Symbol sign(Symbol data) {
9905  return Operator("sign")
9906  .SetInput("data", data)
9907  .CreateSymbol();
9908 }
9909 
9923 inline Symbol round(Symbol data) {
9924  return Operator("round")
9925  .SetInput("data", data)
9926  .CreateSymbol();
9927 }
9928 
9946 inline Symbol rint(Symbol data) {
9947  return Operator("rint")
9948  .SetInput("data", data)
9949  .CreateSymbol();
9950 }
9951 
9967 inline Symbol ceil(Symbol data) {
9968  return Operator("ceil")
9969  .SetInput("data", data)
9970  .CreateSymbol();
9971 }
9972 
9988 inline Symbol floor(Symbol data) {
9989  return Operator("floor")
9990  .SetInput("data", data)
9991  .CreateSymbol();
9992 }
9993 
10010 inline Symbol trunc(Symbol data) {
10011  return Operator("trunc")
10012  .SetInput("data", data)
10013  .CreateSymbol();
10014 }
10015 
10029 inline Symbol fix(Symbol data) {
10030  return Operator("fix")
10031  .SetInput("data", data)
10032  .CreateSymbol();
10033 }
10034 
10051 inline Symbol square(Symbol data) {
10052  return Operator("square")
10053  .SetInput("data", data)
10054  .CreateSymbol();
10055 }
10056 
10073 inline Symbol sqrt(Symbol data) {
10074  return Operator("sqrt")
10075  .SetInput("data", data)
10076  .CreateSymbol();
10077 }
10078 
10095 inline Symbol rsqrt(Symbol data) {
10096  return Operator("rsqrt")
10097  .SetInput("data", data)
10098  .CreateSymbol();
10099 }
10100 
10117 inline Symbol exp(Symbol data) {
10118  return Operator("exp")
10119  .SetInput("data", data)
10120  .CreateSymbol();
10121 }
10122 
10134 inline Symbol log(Symbol data) {
10135  return Operator("log")
10136  .SetInput("data", data)
10137  .CreateSymbol();
10138 }
10139 
10151 inline Symbol log10(Symbol data) {
10152  return Operator("log10")
10153  .SetInput("data", data)
10154  .CreateSymbol();
10155 }
10156 
10168 inline Symbol log2(Symbol data) {
10169  return Operator("log2")
10170  .SetInput("data", data)
10171  .CreateSymbol();
10172 }
10173 
10188 inline Symbol sin(Symbol data) {
10189  return Operator("sin")
10190  .SetInput("data", data)
10191  .CreateSymbol();
10192 }
10193 
10206 inline Symbol log1p(Symbol data) {
10207  return Operator("log1p")
10208  .SetInput("data", data)
10209  .CreateSymbol();
10210 }
10211 
10223 inline Symbol expm1(Symbol data) {
10224  return Operator("expm1")
10225  .SetInput("data", data)
10226  .CreateSymbol();
10227 }
10228 
10243 inline Symbol cos(Symbol data) {
10244  return Operator("cos")
10245  .SetInput("data", data)
10246  .CreateSymbol();
10247 }
10248 
10263 inline Symbol tan(Symbol data) {
10264  return Operator("tan")
10265  .SetInput("data", data)
10266  .CreateSymbol();
10267 }
10268 
10284 inline Symbol arcsin(Symbol data) {
10285  return Operator("arcsin")
10286  .SetInput("data", data)
10287  .CreateSymbol();
10288 }
10289 
10305 inline Symbol arccos(Symbol data) {
10306  return Operator("arccos")
10307  .SetInput("data", data)
10308  .CreateSymbol();
10309 }
10310 
10325 inline Symbol arctan(Symbol data) {
10326  return Operator("arctan")
10327  .SetInput("data", data)
10328  .CreateSymbol();
10329 }
10330 
10343 inline Symbol degrees(Symbol data) {
10344  return Operator("degrees")
10345  .SetInput("data", data)
10346  .CreateSymbol();
10347 }
10348 
10361 inline Symbol radians(Symbol data) {
10362  return Operator("radians")
10363  .SetInput("data", data)
10364  .CreateSymbol();
10365 }
10366 
10379 inline Symbol sinh(Symbol data) {
10380  return Operator("sinh")
10381  .SetInput("data", data)
10382  .CreateSymbol();
10383 }
10384 
10397 inline Symbol cosh(Symbol data) {
10398  return Operator("cosh")
10399  .SetInput("data", data)
10400  .CreateSymbol();
10401 }
10402 
10415 inline Symbol tanh(Symbol data) {
10416  return Operator("tanh")
10417  .SetInput("data", data)
10418  .CreateSymbol();
10419 }
10420 
10429 inline Symbol arcsinh(Symbol data) {
10430  return Operator("arcsinh")
10431  .SetInput("data", data)
10432  .CreateSymbol();
10433 }
10434 
10443 inline Symbol arccosh(Symbol data) {
10444  return Operator("arccosh")
10445  .SetInput("data", data)
10446  .CreateSymbol();
10447 }
10448 
10457 inline Symbol arctanh(Symbol data) {
10458  return Operator("arctanh")
10459  .SetInput("data", data)
10460  .CreateSymbol();
10461 }
10462 
10470 inline Symbol gamma(Symbol data) {
10471  return Operator("gamma")
10472  .SetInput("data", data)
10473  .CreateSymbol();
10474 }
10475 
10483 inline Symbol gammaln(Symbol data) {
10484  return Operator("gammaln")
10485  .SetInput("data", data)
10486  .CreateSymbol();
10487 }
10488 
10540  Symbol weight,
10541  int input_dim,
10542  int output_dim,
10544  static const char *EmbeddingDtypeValues[] = {
10545  "float16",
10546  "float32",
10547  "float64",
10548  "int32",
10549  "uint8"
10550  };
10551  return Operator("Embedding")
10552  .SetParam("input_dim", input_dim)
10553  .SetParam("output_dim", output_dim)
10554  .SetParam("dtype", EmbeddingDtypeValues[int(dtype)])
10555  .SetInput("data", data)
10556  .SetInput("weight", weight)
10557  .CreateSymbol();
10558 }
10559 
10598 inline Symbol take(Symbol a,
10599  Symbol indices,
10600  int axis = 0,
10601  TakeMode mode = TakeMode::kClip) {
10602  static const char *TakeModeValues[] = {
10603  "clip",
10604  "raise",
10605  "wrap"
10606  };
10607  return Operator("take")
10608  .SetParam("axis", axis)
10609  .SetParam("mode", TakeModeValues[int(mode)])
10610  .SetInput("a", a)
10611  .SetInput("indices", indices)
10612  .CreateSymbol();
10613 }
10614 
10643  Symbol indices) {
10644  return Operator("batch_take")
10645  .SetInput("a", a)
10646  .SetInput("indices", indices)
10647  .CreateSymbol();
10648 }
10649 
10693 inline Symbol one_hot(Symbol indices,
10694  int depth,
10695  double on_value = 1,
10696  double off_value = 0,
10698  static const char *One_hotDtypeValues[] = {
10699  "float16",
10700  "float32",
10701  "float64",
10702  "int32",
10703  "uint8"
10704  };
10705  return Operator("one_hot")
10706  .SetParam("depth", depth)
10707  .SetParam("on_value", on_value)
10708  .SetParam("off_value", off_value)
10709  .SetParam("dtype", One_hotDtypeValues[int(dtype)])
10710  .SetInput("indices", indices)
10711  .CreateSymbol();
10712 }
10713 
10736  Symbol rhs) {
10737  return Operator("broadcast_equal")
10738  .SetInput("lhs", lhs)
10739  .SetInput("rhs", rhs)
10740  .CreateSymbol();
10741 }
10742 
10765  Symbol rhs) {
10766  return Operator("broadcast_not_equal")
10767  .SetInput("lhs", lhs)
10768  .SetInput("rhs", rhs)
10769  .CreateSymbol();
10770 }
10771 
10794  Symbol rhs) {
10795  return Operator("broadcast_greater")
10796  .SetInput("lhs", lhs)
10797  .SetInput("rhs", rhs)
10798  .CreateSymbol();
10799 }
10800 
10823  Symbol rhs) {
10824  return Operator("broadcast_greater_equal")
10825  .SetInput("lhs", lhs)
10826  .SetInput("rhs", rhs)
10827  .CreateSymbol();
10828 }
10829 
10852  Symbol rhs) {
10853  return Operator("broadcast_lesser")
10854  .SetInput("lhs", lhs)
10855  .SetInput("rhs", rhs)
10856  .CreateSymbol();
10857 }
10858 
10881  Symbol rhs) {
10882  return Operator("broadcast_lesser_equal")
10883  .SetInput("lhs", lhs)
10884  .SetInput("rhs", rhs)
10885  .CreateSymbol();
10886 }
10887 
10933  Symbol B,
10934  Symbol C,
10935  bool transpose_a = false,
10936  bool transpose_b = false,
10937  double alpha = 1,
10938  double beta = 1) {
10939  return Operator("linalg_gemm")
10940  .SetParam("transpose_a", transpose_a)
10941  .SetParam("transpose_b", transpose_b)
10942  .SetParam("alpha", alpha)
10943  .SetParam("beta", beta)
10944  .SetInput("A", A)
10945  .SetInput("B", B)
10946  .SetInput("C", C)
10947  .CreateSymbol();
10948 }
10949 
10991  Symbol B,
10992  bool transpose_a = false,
10993  bool transpose_b = false,
10994  double alpha = 1) {
10995  return Operator("linalg_gemm2")
10996  .SetParam("transpose_a", transpose_a)
10997  .SetParam("transpose_b", transpose_b)
10998  .SetParam("alpha", alpha)
10999  .SetInput("A", A)
11000  .SetInput("B", B)
11001  .CreateSymbol();
11002 }
11003 
11039  return Operator("linalg_potrf")
11040  .SetInput("A", A)
11041  .CreateSymbol();
11042 }
11043 
11079  return Operator("linalg_potri")
11080  .SetInput("A", A)
11081  .CreateSymbol();
11082 }
11083 
11131  Symbol B,
11132  bool transpose = false,
11133  bool rightside = false,
11134  double alpha = 1) {
11135  return Operator("linalg_trmm")
11136  .SetParam("transpose", transpose)
11137  .SetParam("rightside", rightside)
11138  .SetParam("alpha", alpha)
11139  .SetInput("A", A)
11140  .SetInput("B", B)
11141  .CreateSymbol();
11142 }
11143 
11191  Symbol B,
11192  bool transpose = false,
11193  bool rightside = false,
11194  double alpha = 1) {
11195  return Operator("linalg_trsm")
11196  .SetParam("transpose", transpose)
11197  .SetParam("rightside", rightside)
11198  .SetParam("alpha", alpha)
11199  .SetInput("A", A)
11200  .SetInput("B", B)
11201  .CreateSymbol();
11202 }
11203 
11232  return Operator("linalg_sumlogdiag")
11233  .SetInput("A", A)
11234  .CreateSymbol();
11235 }
11236 
11252 inline Symbol where(Symbol condition,
11253  Symbol x,
11254  Symbol y) {
11255  return Operator("where")
11256  .SetInput("condition", condition)
11257  .SetInput("x", x)
11258  .SetInput("y", y)
11259  .CreateSymbol();
11260 }
11261 
11287  mx_float scalar) {
11288  return Operator("smooth_l1")
11289  .SetParam("scalar", scalar)
11290  .SetInput("data", data)
11291  .CreateSymbol();
11292 }
11293 
11308 inline Symbol Custom(const std::vector<Symbol>& data,
11309  const std::string& op_type) {
11310  return Operator("Custom")
11311 (data)
11312  .CreateSymbol();
11313 }
11314 
11342 inline Symbol SwapAxis(Symbol data,
11343  uint32_t dim1 = 0,
11344  uint32_t dim2 = 0) {
11345  return Operator("SwapAxis")
11346  .SetParam("dim1", dim1)
11347  .SetParam("dim2", dim2)
11348  .SetInput("data", data)
11349  .CreateSymbol();
11350 }
11351 
11379  mx_float slope = 0.25,
11380  mx_float lower_bound = 0.125,
11381  mx_float upper_bound = 0.334) {
11382  static const char *LeakyReLUActTypeValues[] = {
11383  "elu",
11384  "leaky",
11385  "prelu",
11386  "rrelu"
11387  };
11388  return Operator("LeakyReLU")
11389  .SetParam("act_type", LeakyReLUActTypeValues[int(act_type)])
11390  .SetParam("slope", slope)
11391  .SetParam("lower_bound", lower_bound)
11392  .SetParam("upper_bound", upper_bound)
11393  .SetInput("data", data)
11394  .CreateSymbol();
11395 }
11396 
11452  Symbol gamma,
11453  Symbol beta,
11454  mx_float eps = 0.001,
11455  mx_float momentum = 0.9,
11456  bool fix_gamma = true,
11457  bool use_global_stats = false,
11458  bool output_mean_var = false) {
11459  return Operator("BatchNorm_v1")
11460  .SetParam("eps", eps)
11461  .SetParam("momentum", momentum)
11462  .SetParam("fix_gamma", fix_gamma)
11463  .SetParam("use_global_stats", use_global_stats)
11464  .SetParam("output_mean_var", output_mean_var)
11465  .SetInput("data", data)
11466  .SetInput("gamma", gamma)
11467  .SetInput("beta", beta)
11468  .CreateSymbol();
11469 }
11470 
11511 inline Symbol Concat(const std::vector<Symbol>& data,
11512  int num_args,
11513  int dim = 1) {
11514  return Operator("Concat")
11515  .SetParam("num_args", num_args)
11516  .SetParam("dim", dim)
11517 (data)
11518  .CreateSymbol();
11519 }
11520 
11541 inline Symbol sgd_update(Symbol weight,
11542  Symbol grad,
11543  mx_float lr,
11544  mx_float wd = 0,
11545  mx_float rescale_grad = 1,
11546  mx_float clip_gradient = -1) {
11547  return Operator("sgd_update")
11548  .SetParam("lr", lr)
11549  .SetParam("wd", wd)
11550  .SetParam("rescale_grad", rescale_grad)
11551  .SetParam("clip_gradient", clip_gradient)
11552  .SetInput("weight", weight)
11553  .SetInput("grad", grad)
11554  .CreateSymbol();
11555 }
11556 
11592  Symbol grad,
11593  Symbol mom,
11594  mx_float lr,
11595  mx_float momentum = 0,
11596  mx_float wd = 0,
11597  mx_float rescale_grad = 1,
11598  mx_float clip_gradient = -1) {
11599  return Operator("sgd_mom_update")
11600  .SetParam("lr", lr)
11601  .SetParam("momentum", momentum)
11602  .SetParam("wd", wd)
11603  .SetParam("rescale_grad", rescale_grad)
11604  .SetParam("clip_gradient", clip_gradient)
11605  .SetInput("weight", weight)
11606  .SetInput("grad", grad)
11607  .SetInput("mom", mom)
11608  .CreateSymbol();
11609 }
11610 
11625  Symbol grad,
11626  Symbol weight32,
11627  mx_float lr,
11628  mx_float wd = 0,
11629  mx_float rescale_grad = 1,
11630  mx_float clip_gradient = -1) {
11631  return Operator("mp_sgd_update")
11632  .SetParam("lr", lr)
11633  .SetParam("wd", wd)
11634  .SetParam("rescale_grad", rescale_grad)
11635  .SetParam("clip_gradient", clip_gradient)
11636  .SetInput("weight", weight)
11637  .SetInput("grad", grad)
11638  .SetInput("weight32", weight32)
11639  .CreateSymbol();
11640 }
11641 
11658  Symbol grad,
11659  Symbol mom,
11660  Symbol weight32,
11661  mx_float lr,
11662  mx_float momentum = 0,
11663  mx_float wd = 0,
11664  mx_float rescale_grad = 1,
11665  mx_float clip_gradient = -1) {
11666  return Operator("mp_sgd_mom_update")
11667  .SetParam("lr", lr)
11668  .SetParam("momentum", momentum)
11669  .SetParam("wd", wd)
11670  .SetParam("rescale_grad", rescale_grad)
11671  .SetParam("clip_gradient", clip_gradient)
11672  .SetInput("weight", weight)
11673  .SetInput("grad", grad)
11674  .SetInput("mom", mom)
11675  .SetInput("weight32", weight32)
11676  .CreateSymbol();
11677 }
11678 
11717 inline Symbol adam_update(Symbol weight,
11718  Symbol grad,
11719  Symbol mean,
11720  Symbol var,
11721  mx_float lr,
11722  mx_float beta1 = 0.9,
11723  mx_float beta2 = 0.999,
11724  mx_float epsilon = 1e-08,
11725  mx_float wd = 0,
11726  mx_float rescale_grad = 1,
11727  mx_float clip_gradient = -1) {
11728  return Operator("adam_update")
11729  .SetParam("lr", lr)
11730  .SetParam("beta1", beta1)
11731  .SetParam("beta2", beta2)
11732  .SetParam("epsilon", epsilon)
11733  .SetParam("wd", wd)
11734  .SetParam("rescale_grad", rescale_grad)
11735  .SetParam("clip_gradient", clip_gradient)
11736  .SetInput("weight", weight)
11737  .SetInput("grad", grad)
11738  .SetInput("mean", mean)
11739  .SetInput("var", var)
11740  .CreateSymbol();
11741 }
11742 
11796  Symbol grad,
11797  Symbol n,
11798  mx_float lr,
11799  mx_float gamma1 = 0.95,
11800  mx_float epsilon = 1e-08,
11801  mx_float wd = 0,
11802  mx_float rescale_grad = 1,
11803  mx_float clip_gradient = -1,
11804  mx_float clip_weights = -1) {
11805  return Operator("rmsprop_update")
11806  .SetParam("lr", lr)
11807  .SetParam("gamma1", gamma1)
11808  .SetParam("epsilon", epsilon)
11809  .SetParam("wd", wd)
11810  .SetParam("rescale_grad", rescale_grad)
11811  .SetParam("clip_gradient", clip_gradient)
11812  .SetParam("clip_weights", clip_weights)
11813  .SetInput("weight", weight)
11814  .SetInput("grad", grad)
11815  .SetInput("n", n)
11816  .CreateSymbol();
11817 }
11818 
11864  Symbol grad,
11865  Symbol n,
11866  Symbol g,
11867  Symbol delta,
11868  mx_float lr,
11869  mx_float gamma1 = 0.95,
11870  mx_float gamma2 = 0.9,
11871  mx_float epsilon = 1e-08,
11872  mx_float wd = 0,
11873  mx_float rescale_grad = 1,
11874  mx_float clip_gradient = -1,
11875  mx_float clip_weights = -1) {
11876  return Operator("rmspropalex_update")
11877  .SetParam("lr", lr)
11878  .SetParam("gamma1", gamma1)
11879  .SetParam("gamma2", gamma2)
11880  .SetParam("epsilon", epsilon)
11881  .SetParam("wd", wd)
11882  .SetParam("rescale_grad", rescale_grad)
11883  .SetParam("clip_gradient", clip_gradient)
11884  .SetParam("clip_weights", clip_weights)
11885  .SetInput("weight", weight)
11886  .SetInput("grad", grad)
11887  .SetInput("n", n)
11888  .SetInput("g", g)
11889  .SetInput("delta", delta)
11890  .CreateSymbol();
11891 }
11892 
11988 inline Symbol Pad(Symbol data,
11989  PadMode mode,
11990  Shape pad_width,
11991  double constant_value = 0) {
11992  static const char *PadModeValues[] = {
11993  "constant",
11994  "edge",
11995  "reflect"
11996  };
11997  return Operator("Pad")
11998  .SetParam("mode", PadModeValues[int(mode)])
11999  .SetParam("pad_width", pad_width)
12000  .SetParam("constant_value", constant_value)
12001  .SetInput("data", data)
12002  .CreateSymbol();
12003 }
12004 
12014  mx_float sparseness_target = 0.1,
12015  mx_float penalty = 0.001,
12016  mx_float momentum = 0.9) {
12017  return Operator("IdentityAttachKLSparseReg")
12018  .SetParam("sparseness_target", sparseness_target)
12019  .SetParam("penalty", penalty)
12020  .SetParam("momentum", momentum)
12021  .SetInput("data", data)
12022  .CreateSymbol();
12023 }
12024 
12096  int num_outputs,
12097  int axis = 1,
12098  bool squeeze_axis = false) {
12099  return Operator("SliceChannel")
12100  .SetParam("num_outputs", num_outputs)
12101  .SetParam("axis", axis)
12102  .SetParam("squeeze_axis", squeeze_axis)
12103  .SetInput("data", data)
12104  .CreateSymbol();
12105 }
12106 
12144  Symbol label) {
12145  return Operator("softmax_cross_entropy")
12146  .SetInput("data", data)
12147  .SetInput("label", label)
12148  .CreateSymbol();
12149 }
12150 
12165 inline Symbol UpSampling(const std::vector<Symbol>& data,
12166  uint32_t scale,
12167  UpSamplingSampleType sample_type,
12168  int num_args,
12169  uint32_t num_filter = 0,
12171  uint64_t workspace = 512) {
12172  static const char *UpSamplingSampleTypeValues[] = {
12173  "bilinear",
12174  "nearest"
12175  };
12176  static const char *UpSamplingMultiInputModeValues[] = {
12177  "concat",
12178  "sum"
12179  };
12180  return Operator("UpSampling")
12181  .SetParam("scale", scale)
12182  .SetParam("sample_type", UpSamplingSampleTypeValues[int(sample_type)])
12183  .SetParam("num_args", num_args)
12184  .SetParam("num_filter", num_filter)
12185  .SetParam("multi_input_mode", UpSamplingMultiInputModeValues[int(multi_input_mode)])
12186  .SetParam("workspace", workspace)
12187 (data)
12188  .CreateSymbol();
12189 }
12190 
12254  Symbol gamma,
12255  Symbol beta,
12256  Symbol moving_mean,
12257  Symbol moving_var,
12258  double eps = 0.001,
12259  mx_float momentum = 0.9,
12260  bool fix_gamma = true,
12261  bool use_global_stats = false,
12262  bool output_mean_var = false,
12263  int axis = 1,
12264  bool cudnn_off = false) {
12265  return Operator("BatchNorm")
12266  .SetParam("eps", eps)
12267  .SetParam("momentum", momentum)
12268  .SetParam("fix_gamma", fix_gamma)
12269  .SetParam("use_global_stats", use_global_stats)
12270  .SetParam("output_mean_var", output_mean_var)
12271  .SetParam("axis", axis)
12272  .SetParam("cudnn_off", cudnn_off)
12273  .SetInput("data", data)
12274  .SetInput("gamma", gamma)
12275  .SetInput("beta", beta)
12276  .SetInput("moving_mean", moving_mean)
12277  .SetInput("moving_var", moving_var)
12278  .CreateSymbol();
12279 }
12280 
12331  Symbol gamma,
12332  Symbol beta,
12333  mx_float eps = 0.001) {
12334  return Operator("InstanceNorm")
12335  .SetParam("eps", eps)
12336  .SetInput("data", data)
12337  .SetInput("gamma", gamma)
12338  .SetInput("beta", beta)
12339  .CreateSymbol();
12340 }
12341 
12356 inline Symbol RNN(Symbol data,
12357  Symbol parameters,
12358  Symbol state,
12359  Symbol state_cell,
12360  uint32_t state_size,
12361  uint32_t num_layers,
12362  RNNMode mode,
12363  bool bidirectional = false,
12364  mx_float p = 0,
12365  bool state_outputs = false) {
12366  static const char *RNNModeValues[] = {
12367  "gru",
12368  "lstm",
12369  "rnn_relu",
12370  "rnn_tanh"
12371  };
12372  return Operator("RNN")
12373  .SetParam("state_size", state_size)
12374  .SetParam("num_layers", num_layers)
12375  .SetParam("mode", RNNModeValues[int(mode)])
12376  .SetParam("bidirectional", bidirectional)
12377  .SetParam("p", p)
12378  .SetParam("state_outputs", state_outputs)
12379  .SetInput("data", data)
12380  .SetInput("parameters", parameters)
12381  .SetInput("state", state)
12382  .SetInput("state_cell", state_cell)
12383  .CreateSymbol();
12384 }
12385 
12414  Symbol weight,
12415  Symbol bias,
12416  Shape kernel,
12417  uint32_t num_filter,
12418  Shape stride = Shape(),
12419  Shape dilate = Shape(),
12420  Shape pad = Shape(),
12421  uint32_t num_group = 1,
12422  uint64_t workspace = 1024,
12423  bool no_bias = false,
12425  bool cudnn_off = false,
12427  static const char *Convolution_v1CudnnTuneValues[] = {
12428  "None",
12429  "fastest",
12430  "limited_workspace",
12431  "off"
12432  };
12433  static const char *Convolution_v1LayoutValues[] = {
12434  "None",
12435  "NCDHW",
12436  "NCHW",
12437  "NDHWC",
12438  "NHWC"
12439  };
12440  return Operator("Convolution_v1")
12441  .SetParam("kernel", kernel)
12442  .SetParam("num_filter", num_filter)
12443  .SetParam("stride", stride)
12444  .SetParam("dilate", dilate)
12445  .SetParam("pad", pad)
12446  .SetParam("num_group", num_group)
12447  .SetParam("workspace", workspace)
12448  .SetParam("no_bias", no_bias)
12449  .SetParam("cudnn_tune", Convolution_v1CudnnTuneValues[int(cudnn_tune)])
12450  .SetParam("cudnn_off", cudnn_off)
12451  .SetParam("layout", Convolution_v1LayoutValues[int(layout)])
12452  .SetInput("data", data)
12453  .SetInput("weight", weight)
12454  .SetInput("bias", bias)
12455  .CreateSymbol();
12456 }
12457 
12477 inline Symbol Crop(const std::vector<Symbol>& data,
12478  int num_args,
12479  Shape offset = Shape(0,0),
12480  Shape h_w = Shape(0,0),
12481  bool center_crop = false) {
12482  return Operator("Crop")
12483  .SetParam("num_args", num_args)
12484  .SetParam("offset", offset)
12485  .SetParam("h_w", h_w)
12486  .SetParam("center_crop", center_crop)
12487 (data)
12488  .CreateSymbol();
12489 }
12490 
12501  Symbol loc,
12502  SpatialTransformerTransformType transform_type,
12503  SpatialTransformerSamplerType sampler_type,
12504  Shape target_shape = Shape(0,0)) {
12505  static const char *SpatialTransformerTransformTypeValues[] = {
12506  "affine"
12507  };
12508  static const char *SpatialTransformerSamplerTypeValues[] = {
12509  "bilinear"
12510  };
12511  return Operator("SpatialTransformer")
12512  .SetParam("transform_type", SpatialTransformerTransformTypeValues[int(transform_type)])
12513  .SetParam("sampler_type", SpatialTransformerSamplerTypeValues[int(sampler_type)])
12514  .SetParam("target_shape", target_shape)
12515  .SetInput("data", data)
12516  .SetInput("loc", loc)
12517  .CreateSymbol();
12518 }
12519 
12546  Symbol weight,
12547  Symbol bias,
12548  Shape kernel,
12549  uint32_t num_filter,
12550  Shape stride = Shape(),
12551  Shape dilate = Shape(),
12552  Shape pad = Shape(),
12553  Shape adj = Shape(),
12554  Shape target_shape = Shape(),
12555  uint32_t num_group = 1,
12556  uint64_t workspace = 512,
12557  bool no_bias = true,
12559  bool cudnn_off = false,
12561  static const char *DeconvolutionCudnnTuneValues[] = {
12562  "None",
12563  "fastest",
12564  "limited_workspace",
12565  "off"
12566  };
12567  static const char *DeconvolutionLayoutValues[] = {
12568  "None",
12569  "NCDHW",
12570  "NCHW",
12571  "NCW",
12572  "NDHWC",
12573  "NHWC"
12574  };
12575  return Operator("Deconvolution")
12576  .SetParam("kernel", kernel)
12577  .SetParam("num_filter", num_filter)
12578  .SetParam("stride", stride)
12579  .SetParam("dilate", dilate)
12580  .SetParam("pad", pad)
12581  .SetParam("adj", adj)
12582  .SetParam("target_shape", target_shape)
12583  .SetParam("num_group", num_group)
12584  .SetParam("workspace", workspace)
12585  .SetParam("no_bias", no_bias)
12586  .SetParam("cudnn_tune", DeconvolutionCudnnTuneValues[int(cudnn_tune)])
12587  .SetParam("cudnn_off", cudnn_off)
12588  .SetParam("layout", DeconvolutionLayoutValues[int(layout)])
12589  .SetInput("data", data)
12590  .SetInput("weight", weight)
12591  .SetInput("bias", bias)
12592  .CreateSymbol();
12593 }
12594 
12687  Symbol label,
12688  mx_float grad_scale = 1,
12689  mx_float ignore_label = -1,
12690  bool multi_output = false,
12691  bool use_ignore = false,
12692  bool preserve_shape = false,
12694  bool out_grad = false) {
12695  static const char *SoftmaxOutputNormalizationValues[] = {
12696  "batch",
12697  "null",
12698  "valid"
12699  };
12700  return Operator("SoftmaxOutput")
12701  .SetParam("grad_scale", grad_scale)
12702  .SetParam("ignore_label", ignore_label)
12703  .SetParam("multi_output", multi_output)
12704  .SetParam("use_ignore", use_ignore)
12705  .SetParam("preserve_shape", preserve_shape)
12706  .SetParam("normalization", SoftmaxOutputNormalizationValues[int(normalization)])
12707  .SetParam("out_grad", out_grad)
12708  .SetInput("data", data)
12709  .SetInput("label", label)
12710  .CreateSymbol();
12711 }
12712 
12736 inline Symbol Softmax(Symbol data,
12737  mx_float grad_scale = 1,
12738  mx_float ignore_label = -1,
12739  bool multi_output = false,
12740  bool use_ignore = false,
12741  bool preserve_shape = false,
12743  bool out_grad = false) {
12744  static const char *SoftmaxNormalizationValues[] = {
12745  "batch",
12746  "null",
12747  "valid"
12748  };
12749  return Operator("Softmax")
12750  .SetParam("grad_scale", grad_scale)
12751  .SetParam("ignore_label", ignore_label)
12752  .SetParam("multi_output", multi_output)
12753  .SetParam("use_ignore", use_ignore)
12754  .SetParam("preserve_shape", preserve_shape)
12755  .SetParam("normalization", SoftmaxNormalizationValues[int(normalization)])
12756  .SetParam("out_grad", out_grad)
12757  .SetInput("data", data)
12758  .CreateSymbol();
12759 }
12760 
12836  Symbol sequence_length,
12837  bool use_sequence_length = false) {
12838  return Operator("SequenceReverse")
12839  .SetParam("use_sequence_length", use_sequence_length)
12840  .SetInput("data", data)
12841  .SetInput("sequence_length", sequence_length)
12842  .CreateSymbol();
12843 }
12844 
12899  Symbol sequence_length,
12900  bool use_sequence_length = false) {
12901  return Operator("SequenceLast")
12902  .SetParam("use_sequence_length", use_sequence_length)
12903  .SetInput("data", data)
12904  .SetInput("sequence_length", sequence_length)
12905  .CreateSymbol();
12906 }
12907 
12956  Symbol data2,
12957  uint32_t kernel_size = 1,
12958  uint32_t max_displacement = 1,
12959  uint32_t stride1 = 1,
12960  uint32_t stride2 = 1,
12961  uint32_t pad_size = 0,
12962  bool is_multiply = true) {
12963  return Operator("Correlation")
12964  .SetParam("kernel_size", kernel_size)
12965  .SetParam("max_displacement", max_displacement)
12966  .SetParam("stride1", stride1)
12967  .SetParam("stride2", stride2)
12968  .SetParam("pad_size", pad_size)
12969  .SetParam("is_multiply", is_multiply)
12970  .SetInput("data1", data1)
12971  .SetInput("data2", data2)
12972  .CreateSymbol();
12973 }
12974 
12990  Symbol label,
12991  mx_float margin = 1,
12992  mx_float regularization_coefficient = 1,
12993  bool use_linear = false) {
12994  return Operator("SVMOutput")
12995  .SetParam("margin", margin)
12996  .SetParam("regularization_coefficient", regularization_coefficient)
12997  .SetParam("use_linear", use_linear)
12998  .SetInput("data", data)
12999  .SetInput("label", label)
13000  .CreateSymbol();
13001 }
13002 
13065  mx_float eps = 1e-10,
13067  static const char *L2NormalizationModeValues[] = {
13068  "channel",
13069  "instance",
13070  "spatial"
13071  };
13072  return Operator("L2Normalization")
13073  .SetParam("eps", eps)
13074  .SetParam("mode", L2NormalizationModeValues[int(mode)])
13075  .SetInput("data", data)
13076  .CreateSymbol();
13077 }
13078 
13105 inline Symbol LRN(Symbol data,
13106  uint32_t nsize,
13107  mx_float alpha = 0.0001,
13108  mx_float beta = 0.75,
13109  mx_float knorm = 2) {
13110  return Operator("LRN")
13111  .SetParam("nsize", nsize)
13112  .SetParam("alpha", alpha)
13113  .SetParam("beta", beta)
13114  .SetParam("knorm", knorm)
13115  .SetInput("data", data)
13116  .CreateSymbol();
13117 }
13118 
13144  Symbol weight,
13145  Symbol bias,
13146  int num_hidden,
13147  bool no_bias = false) {
13148  return Operator("FullyConnected")
13149  .SetParam("num_hidden", num_hidden)
13150  .SetParam("no_bias", no_bias)
13151  .SetInput("data", data)
13152  .SetInput("weight", weight)
13153  .SetInput("bias", bias)
13154  .CreateSymbol();
13155 }
13156 
13234  Symbol sequence_length,
13235  bool use_sequence_length = false,
13236  mx_float value = 0) {
13237  return Operator("SequenceMask")
13238  .SetParam("use_sequence_length", use_sequence_length)
13239  .SetParam("value", value)
13240  .SetInput("data", data)
13241  .SetInput("sequence_length", sequence_length)
13242  .CreateSymbol();
13243 }
13244 
13255  GridGeneratorTransformType transform_type,
13256  Shape target_shape = Shape(0,0)) {
13257  static const char *GridGeneratorTransformTypeValues[] = {
13258  "affine",
13259  "warp"
13260  };
13261  return Operator("GridGenerator")
13262  .SetParam("transform_type", GridGeneratorTransformTypeValues[int(transform_type)])
13263  .SetParam("target_shape", target_shape)
13264  .SetInput("data", data)
13265  .CreateSymbol();
13266 }
13267 
13319  Shape kernel,
13320  Pooling_v1PoolType pool_type,
13321  bool global_pool = false,
13323  Shape stride = Shape(),
13324  Shape pad = Shape()) {
13325  static const char *Pooling_v1PoolTypeValues[] = {
13326  "avg",
13327  "max",
13328  "sum"
13329  };
13330  static const char *Pooling_v1PoolingConventionValues[] = {
13331  "full",
13332  "valid"
13333  };
13334  return Operator("Pooling_v1")
13335  .SetParam("kernel", kernel)
13336  .SetParam("pool_type", Pooling_v1PoolTypeValues[int(pool_type)])
13337  .SetParam("global_pool", global_pool)
13338  .SetParam("pooling_convention", Pooling_v1PoolingConventionValues[int(pooling_convention)])
13339  .SetParam("stride", stride)
13340  .SetParam("pad", pad)
13341  .SetInput("data", data)
13342  .CreateSymbol();
13343 }
13344 
13439  Symbol weight,
13440  Symbol bias,
13441  Shape kernel,
13442  uint32_t num_filter,
13443  Shape stride = Shape(),
13444  Shape dilate = Shape(),
13445  Shape pad = Shape(),
13446  uint32_t num_group = 1,
13447  uint64_t workspace = 1024,
13448  bool no_bias = false,
13450  bool cudnn_off = false,
13452  static const char *ConvolutionCudnnTuneValues[] = {
13453  "None",
13454  "fastest",
13455  "limited_workspace",
13456  "off"
13457  };
13458  static const char *ConvolutionLayoutValues[] = {
13459  "None",
13460  "NCDHW",
13461  "NCHW",
13462  "NCW",
13463  "NDHWC",
13464  "NHWC"
13465  };
13466  return Operator("Convolution")
13467  .SetParam("kernel", kernel)
13468  .SetParam("num_filter", num_filter)
13469  .SetParam("stride", stride)
13470  .SetParam("dilate", dilate)
13471  .SetParam("pad", pad)
13472  .SetParam("num_group", num_group)
13473  .SetParam("workspace", workspace)
13474  .SetParam("no_bias", no_bias)
13475  .SetParam("cudnn_tune", ConvolutionCudnnTuneValues[int(cudnn_tune)])
13476  .SetParam("cudnn_off", cudnn_off)
13477  .SetParam("layout", ConvolutionLayoutValues[int(layout)])
13478  .SetInput("data", data)
13479  .SetInput("weight", weight)
13480  .SetInput("bias", bias)
13481  .CreateSymbol();
13482 }
13483 
13564  Symbol grid) {
13565  return Operator("BilinearSampler")
13566  .SetInput("data", data)
13567  .SetInput("grid", grid)
13568  .CreateSymbol();
13569 }
13570 
13623 inline Symbol Pooling(Symbol data,
13624  Shape kernel,
13625  PoolingPoolType pool_type,
13626  bool global_pool = false,
13627  bool cudnn_off = false,
13629  Shape stride = Shape(),
13630  Shape pad = Shape()) {
13631  static const char *PoolingPoolTypeValues[] = {
13632  "avg",
13633  "max",
13634  "sum"
13635  };
13636  static const char *PoolingPoolingConventionValues[] = {
13637  "full",
13638  "valid"
13639  };
13640  return Operator("Pooling")
13641  .SetParam("kernel", kernel)
13642  .SetParam("pool_type", PoolingPoolTypeValues[int(pool_type)])
13643  .SetParam("global_pool", global_pool)
13644  .SetParam("cudnn_off", cudnn_off)
13645  .SetParam("pooling_convention", PoolingPoolingConventionValues[int(pooling_convention)])
13646  .SetParam("stride", stride)
13647  .SetParam("pad", pad)
13648  .SetInput("data", data)
13649  .CreateSymbol();
13650 }
13651 
13690 inline Symbol Dropout(Symbol data,
13691  mx_float p = 0.5,
13693  static const char *DropoutModeValues[] = {
13694  "always",
13695  "training"
13696  };
13697  return Operator("Dropout")
13698  .SetParam("p", p)
13699  .SetParam("mode", DropoutModeValues[int(mode)])
13700  .SetInput("data", data)
13701  .CreateSymbol();
13702 }
13703 
13722  ActivationActType act_type) {
13723  static const char *ActivationActTypeValues[] = {
13724  "relu",
13725  "sigmoid",
13726  "softrelu",
13727  "tanh"
13728  };
13729  return Operator("Activation")
13730  .SetParam("act_type", ActivationActTypeValues[int(act_type)])
13731  .SetInput("data", data)
13732  .CreateSymbol();
13733 }
13734 
13791  Symbol rois,
13792  Shape pooled_size,
13793  mx_float spatial_scale) {
13794  return Operator("ROIPooling")
13795  .SetParam("pooled_size", pooled_size)
13796  .SetParam("spatial_scale", spatial_scale)
13797  .SetInput("data", data)
13798  .SetInput("rois", rois)
13799  .CreateSymbol();
13800 }
13801 
13826  Symbol label,
13827  mx_float grad_scale = 1) {
13828  return Operator("LinearRegressionOutput")
13829  .SetParam("grad_scale", grad_scale)
13830  .SetInput("data", data)
13831  .SetInput("label", label)
13832  .CreateSymbol();
13833 }
13834 
13860  Symbol label,
13861  mx_float grad_scale = 1) {
13862  return Operator("MAERegressionOutput")
13863  .SetParam("grad_scale", grad_scale)
13864  .SetInput("data", data)
13865  .SetInput("label", label)
13866  .CreateSymbol();
13867 }
13868 
13894  Symbol label,
13895  mx_float grad_scale = 1) {
13896  return Operator("LogisticRegressionOutput")
13897  .SetParam("grad_scale", grad_scale)
13898  .SetInput("data", data)
13899  .SetInput("label", label)
13900  .CreateSymbol();
13901 }
13902 
13937  static const char *SoftmaxActivationModeValues[] = {
13938  "channel",
13939  "instance"
13940  };
13941  return Operator("SoftmaxActivation")
13942  .SetParam("mode", SoftmaxActivationModeValues[int(mode)])
13943  .SetInput("data", data)
13944  .CreateSymbol();
13945 }
13946 
13980 inline Symbol MakeLoss(Symbol data,
13981  mx_float grad_scale = 1,
13982  mx_float valid_thresh = 0,
13984  static const char *MakeLossNormalizationValues[] = {
13985  "batch",
13986  "null",
13987  "valid"
13988  };
13989  return Operator("MakeLoss")
13990  .SetParam("grad_scale", grad_scale)
13991  .SetParam("valid_thresh", valid_thresh)
13992  .SetParam("normalization", MakeLossNormalizationValues[int(normalization)])
13993  .SetInput("data", data)
13994  .CreateSymbol();
13995 }
13996 
14005  Symbol rhs) {
14006  return Operator("choose_element_0index")
14007  .SetInput("lhs", lhs)
14008  .SetInput("rhs", rhs)
14009  .CreateSymbol();
14010 }
14011 
14021  Symbol mhs,
14022  Symbol rhs) {
14023  return Operator("fill_element_0index")
14024  .SetInput("lhs", lhs)
14025  .SetInput("mhs", mhs)
14026  .SetInput("rhs", rhs)
14027  .CreateSymbol();
14028 }
14029 
14030 } //namespace cpp
14031 } //namespace mxnet
14032 #endif // MXNET_CPP_OP_H_
Symbol Convolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, ConvolutionCudnnTune cudnn_tune=ConvolutionCudnnTune::kNone, bool cudnn_off=false, ConvolutionLayout layout=ConvolutionLayout::kNone)
Definition: op.h:6729
Symbol sample_multinomial(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool get_prob=false, Sample_multinomialDtype dtype=Sample_multinomialDtype::kInt32)
Definition: op.h:135
Symbol fix(const std::string &symbol_name, Symbol data)
Definition: op.h:2959
Symbol Crop(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, Shape offset=Shape(0, 0), Shape h_w=Shape(0, 0), bool center_crop=false)
Definition: op.h:5638
Symbol min(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2319
Symbol broadcast_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1786
Symbol linalg_gemm(const std::string &symbol_name, Symbol A, Symbol B, Symbol C, bool transpose_a=false, bool transpose_b=false, double alpha=1, double beta=1)
Definition: op.h:3964
Symbol arcsin(const std::string &symbol_name, Symbol data)
Definition: op.h:3240
Symbol arccosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3417
Symbol arctan(const std::string &symbol_name, Symbol data)
Definition: op.h:3285
Symbol SwapAxis(const std::string &symbol_name, Symbol data, uint32_t dim1=0, uint32_t dim2=0)
Definition: op.h:4394
Symbol nansum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2206
Symbol add_n(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:1871
Symbol log1p(const std::string &symbol_name, Symbol data)
Definition: op.h:3154
SoftmaxActivationMode
Definition: op.h:7245
Symbol mp_sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol weight32, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4697
Symbol SpatialTransformer(const std::string &symbol_name, Symbol data, Symbol loc, SpatialTransformerTransformType transform_type, SpatialTransformerSamplerType sampler_type, Shape target_shape=Shape(0, 0))
Definition: op.h:5675
Sample_uniformDtype
Definition: op.h:153
Symbol exp(const std::string &symbol_name, Symbol data)
Definition: op.h:3055
Symbol transpose(const std::string &symbol_name, Symbol data, Shape axes=Shape())
Definition: op.h:1244
Symbol clip(const std::string &symbol_name, Symbol data, mx_float a_min, mx_float a_max)
Definition: op.h:1468
Symbol ROIPooling(const std::string &symbol_name, Symbol data, Symbol rois, Shape pooled_size, mx_float spatial_scale)
Definition: op.h:7122
Symbol broadcast_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1817
Random_poissonDtype
Definition: op.h:798
Random_negative_binomialDtype
Definition: op.h:845
Symbol nanprod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2245
Convolution_v1Layout
Definition: op.h:5536
Symbol argmin(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1951
Symbol linalg_potri(const std::string &symbol_name, Symbol A)
Definition: op.h:4116
Symbol linalg_trmm(const std::string &symbol_name, Symbol A, Symbol B, bool transpose=false, bool rightside=false, double alpha=1)
Definition: op.h:4170
Symbol mp_sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, Symbol weight32, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4732
Symbol broadcast_lesser(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3879
Symbol fill_element_0index(const std::string &symbol_name, Symbol lhs, Symbol mhs, Symbol rhs)
Definition: op.h:7384
Symbol Convolution_v1(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, Convolution_v1CudnnTune cudnn_tune=Convolution_v1CudnnTune::kNone, bool cudnn_off=false, Convolution_v1Layout layout=Convolution_v1Layout::kNone)
Definition: op.h:5572
Symbol broadcast_not_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3786
TakeMode
Definition: op.h:3557
Symbol Embedding(const std::string &symbol_name, Symbol data, Symbol weight, int input_dim, int output_dim, EmbeddingDtype dtype=EmbeddingDtype::kFloat32)
Definition: op.h:3531
Symbol reciprocal(const std::string &symbol_name, Symbol data)
Definition: op.h:2780
TopkRetTyp
Definition: op.h:2438
Symbol RNN(const std::string &symbol_name, Symbol data, Symbol parameters, Symbol state, Symbol state_cell, uint32_t state_size, uint32_t num_layers, RNNMode mode, bool bidirectional=false, mx_float p=0, bool state_outputs=false)
Definition: op.h:5486
namespace of mxnet
Definition: base.h:126
Sample_exponentialDtype
Definition: op.h:345
Pooling_v1PoolingConvention
Definition: op.h:6530
Symbol broadcast_lesser_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3910
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
Symbol InstanceNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001)
Definition: op.h:5449
Symbol sign(const std::string &symbol_name, Symbol data)
Definition: op.h:2822
GridGeneratorTransformType
Definition: op.h:6490
dynamic shape class that can hold shape of arbitrary dimension
Definition: shape.h:42
Symbol ones_like(const std::string &symbol_name, Symbol data)
Definition: op.h:1685
RNNMode
Definition: op.h:5464
PadMode
Definition: op.h:4978
Symbol smooth_l1(const std::string &symbol_name, Symbol data, mx_float scalar)
Definition: op.h:4334
Symbol where(const std::string &symbol_name, Symbol condition, Symbol x, Symbol y)
Definition: op.h:4298
Symbol Dropout(const std::string &symbol_name, Symbol data, mx_float p=0.5, DropoutMode mode=DropoutMode::kTraining)
Definition: op.h:7009
Symbol expm1(const std::string &symbol_name, Symbol data)
Definition: op.h:3173
Symbol elemwise_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2598
PoolingPoolType
Definition: op.h:6867
Symbol relu(const std::string &symbol_name, Symbol data)
Definition: op.h:2620
Symbol reverse(const std::string &symbol_name, Symbol data, Shape axis)
Definition: op.h:1601
Symbol rsqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:3031
Symbol Pooling(const std::string &symbol_name, Symbol data, Shape kernel, PoolingPoolType pool_type, bool global_pool=false, bool cudnn_off=false, PoolingPoolingConvention pooling_convention=PoolingPoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:6933
Symbol sample_exponential(const std::string &symbol_name, Symbol lam, Shape shape=Shape(), Sample_exponentialDtype dtype=Sample_exponentialDtype::kNone)
Definition: op.h:386
Symbol sample_poisson(const std::string &symbol_name, Symbol lam, Shape shape=Shape(), Sample_poissonDtype dtype=Sample_poissonDtype::kNone)
Definition: op.h:448
Symbol Pooling_v1(const std::string &symbol_name, Symbol data, Shape kernel, Pooling_v1PoolType pool_type, bool global_pool=false, Pooling_v1PoolingConvention pooling_convention=Pooling_v1PoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:6586
SpatialTransformerTransformType
Definition: op.h:5655
Symbol random_exponential(const std::string &symbol_name, mx_float lam=1, Shape shape=Shape(), const std::string &ctx="", Random_exponentialDtype dtype=Random_exponentialDtype::kNone)
Definition: op.h:778
ActivationActType
Definition: op.h:7026
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:3007
Symbol random_normal(const std::string &symbol_name, mx_float loc=0, mx_float scale=1, Shape shape=Shape(), const std::string &ctx="", Random_normalDtype dtype=Random_normalDtype::kNone)
Definition: op.h:681
Symbol rint(const std::string &symbol_name, Symbol data)
Definition: op.h:2868
Symbol IdentityAttachKLSparseReg(const std::string &symbol_name, Symbol data, mx_float sparseness_target=0.1, mx_float penalty=0.001, mx_float momentum=0.9)
Definition: op.h:5107
Symbol sinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3345
Symbol broadcast_greater_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3848
Symbol LRN(const std::string &symbol_name, Symbol data, uint32_t nsize, mx_float alpha=0.0001, mx_float beta=0.75, mx_float knorm=2)
Definition: op.h:6342
Symbol sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4662
Symbol max(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2282
Symbol arcsinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3401
Random_gammaDtype
Definition: op.h:703
Symbol MAERegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:7195
Symbol sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4610
Symbol SliceChannel(const std::string &symbol_name, Symbol data, int num_outputs, int axis=1, bool squeeze_axis=false)
Definition: op.h:5191
PoolingPoolingConvention
Definition: op.h:6875
Symbol broadcast_minimum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1034
Sample_gammaDtype
Definition: op.h:281
Symbol broadcast_maximum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1001
Random_generalized_negative_binomialDtype
Definition: op.h:896
Symbol Cast(const std::string &symbol_name, Symbol data, CastDtype dtype)
Definition: op.h:2733
DeconvolutionLayout
Definition: op.h:5707
Symbol trunc(const std::string &symbol_name, Symbol data)
Definition: op.h:2938
Pooling_v1PoolType
Definition: op.h:6522
Symbol sample_generalized_negative_binomial(const std::string &symbol_name, Symbol mu, Symbol alpha, Shape shape=Shape(), Sample_generalized_negative_binomialDtype dtype=Sample_generalized_negative_binomialDtype::kNone)
Definition: op.h:578
Symbol round(const std::string &symbol_name, Symbol data)
Definition: op.h:2843
Sample_normalDtype
Definition: op.h:217
Symbol log_softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:82
Symbol cos(const std::string &symbol_name, Symbol data)
Definition: op.h:3195
Symbol sample_uniform(const std::string &symbol_name, Symbol low, Symbol high, Shape shape=Shape(), Sample_uniformDtype dtype=Sample_uniformDtype::kNone)
Definition: op.h:196
Symbol SequenceMask(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, mx_float value=0)
Definition: op.h:6474
Symbol L2Normalization(const std::string &symbol_name, Symbol data, mx_float eps=1e-10, L2NormalizationMode mode=L2NormalizationMode::kInstance)
Definition: op.h:6299
Symbol Correlation(const std::string &symbol_name, Symbol data1, Symbol data2, uint32_t kernel_size=1, uint32_t max_displacement=1, uint32_t stride1=1, uint32_t stride2=1, uint32_t pad_size=0, bool is_multiply=true)
Definition: op.h:6178
Symbol random_gamma(const std::string &symbol_name, mx_float alpha=1, mx_float beta=1, Shape shape=Shape(), const std::string &ctx="", Random_gammaDtype dtype=Random_gammaDtype::kNone)
Definition: op.h:730
Symbol zeros_like(const std::string &symbol_name, Symbol data)
Definition: op.h:1661
EmbeddingDtype
Definition: op.h:3472
Symbol batch_dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1432
Symbol broadcast_mod(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1848
Symbol FullyConnected(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, int num_hidden, bool no_bias=false)
Definition: op.h:6382
Symbol prod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2167
operator helper functions
Symbol mean(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2130
Symbol tanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3385
Symbol random_poisson(const std::string &symbol_name, mx_float lam=1, Shape shape=Shape(), const std::string &ctx="", Random_poissonDtype dtype=Random_poissonDtype::kNone)
Definition: op.h:825
Symbol broadcast_to(const std::string &symbol_name, Symbol data, Shape shape=Shape())
Definition: op.h:2400
Symbol sample_negative_binomial(const std::string &symbol_name, Symbol k, Symbol p, Shape shape=Shape(), Sample_negative_binomialDtype dtype=Sample_negative_binomialDtype::kNone)
Definition: op.h:512
DropoutMode
Definition: op.h:6965
Symbol MakeLoss(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float valid_thresh=0, MakeLossNormalization normalization=MakeLossNormalization::kNull)
Definition: op.h:7340
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:3074
Symbol sigmoid(const std::string &symbol_name, Symbol data)
Definition: op.h:2640
Symbol SequenceReverse(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false)
Definition: op.h:6054
CastDtype
Definition: op.h:2706
ConvolutionLayout
Definition: op.h:6626
Random_exponentialDtype
Definition: op.h:752
Symbol LogisticRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:7231
Symbol gamma(const std::string &symbol_name, Symbol data)
Definition: op.h:3448
Symbol sin(const std::string &symbol_name, Symbol data)
Definition: op.h:3134
Random_uniformDtype
Definition: op.h:599
UpSamplingMultiInputMode
Definition: op.h:5260
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
SpatialTransformerSamplerType
Definition: op.h:5661
Symbol Pad(const std::string &symbol_name, Symbol data, PadMode mode, Shape pad_width, double constant_value=0)
Definition: op.h:5080
Symbol square(const std::string &symbol_name, Symbol data)
Definition: op.h:2983
One_hotDtype
Definition: op.h:3659
UpSamplingSampleType
Definition: op.h:5252
Sample_poissonDtype
Definition: op.h:405
Symbol random_generalized_negative_binomial(const std::string &symbol_name, mx_float mu=1, mx_float alpha=1, Shape shape=Shape(), const std::string &ctx="", Random_generalized_negative_binomialDtype dtype=Random_generalized_negative_binomialDtype::kNone)
Definition: op.h:926
Symbol norm(const std::string &symbol_name, Symbol data)
Definition: op.h:2426
Symbol rmsprop_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, mx_float lr, mx_float gamma1=0.95, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:4874
Symbol LeakyReLU(const std::string &symbol_name, Symbol data, LeakyReLUActType act_type=LeakyReLUActType::kLeaky, mx_float slope=0.25, mx_float lower_bound=0.125, mx_float upper_bound=0.334)
Definition: op.h:4440
Symbol make_loss(const std::string &symbol_name, Symbol data)
Definition: op.h:2697
Symbol SoftmaxOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxOutputNormalization normalization=SoftmaxOutputNormalization::kNull, bool out_grad=false)
Definition: op.h:5893
Symbol SoftmaxActivation(const std::string &symbol_name, Symbol data, SoftmaxActivationMode mode=SoftmaxActivationMode::kInstance)
Definition: op.h:7283
Symbol Softmax(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxNormalization normalization=SoftmaxNormalization::kNull, bool out_grad=false)
Definition: op.h:5953
Symbol slice(const std::string &symbol_name, Symbol data, Shape begin, Shape end)
Definition: op.h:1310
Symbol broadcast_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3755
Symbol Deconvolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), Shape adj=Shape(), Shape target_shape=Shape(), uint32_t num_group=1, uint64_t workspace=512, bool no_bias=true, DeconvolutionCudnnTune cudnn_tune=DeconvolutionCudnnTune::kNone, bool cudnn_off=false, DeconvolutionLayout layout=DeconvolutionLayout::kNone)
Definition: op.h:5742
Symbol broadcast_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1719
Symbol linalg_trsm(const std::string &symbol_name, Symbol A, Symbol B, bool transpose=false, bool rightside=false, double alpha=1)
Definition: op.h:4232
Symbol adam_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mean, Symbol var, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4794
Sample_multinomialDtype
Definition: op.h:93
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:57
Symbol tan(const std::string &symbol_name, Symbol data)
Definition: op.h:3217
Convolution_v1CudnnTune
Definition: op.h:5526
Symbol repeat(const std::string &symbol_name, Symbol data, int repeats, dmlc::optional< int > axis=dmlc::optional< int >())
Definition: op.h:1513
Symbol slice_axis(const std::string &symbol_name, Symbol data, int axis, int begin, dmlc::optional< int > end)
Definition: op.h:1353
Symbol expand_dims(const std::string &symbol_name, Symbol data, int axis)
Definition: op.h:1268
Symbol arctanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3433
Symbol softmax_cross_entropy(const std::string &symbol_name, Symbol data, Symbol label)
Definition: op.h:5241
Symbol pick(const std::string &symbol_name, Symbol data, Symbol index, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:2036
Symbol broadcast_axis(const std::string &symbol_name, Symbol data, Shape axis=Shape(), Shape size=Shape())
Definition: op.h:2361
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:2801
Symbol sample_normal(const std::string &symbol_name, Symbol mu, Symbol sigma, Shape shape=Shape(), Sample_normalDtype dtype=Sample_normalDtype::kNone)
Definition: op.h:260
Symbol cosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3365
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2539
Symbol BilinearSampler(const std::string &symbol_name, Symbol data, Symbol grid)
Definition: op.h:6856
Symbol Custom(const std::string &symbol_name, const std::vector< Symbol > &data, const std::string &op_type)
Definition: op.h:4358
Symbol linalg_sumlogdiag(const std::string &symbol_name, Symbol A)
Definition: op.h:4275
Symbol broadcast_hypot(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1073
Symbol BatchNorm_v1(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false)
Definition: op.h:4516
Symbol UpSampling(const std::string &symbol_name, const std::vector< Symbol > &data, uint32_t scale, UpSamplingSampleType sample_type, int num_args, uint32_t num_filter=0, UpSamplingMultiInputMode multi_input_mode=UpSamplingMultiInputMode::kConcat, uint64_t workspace=512)
Definition: op.h:5280
Symbol Activation(const std::string &symbol_name, Symbol data, ActivationActType act_type)
Definition: op.h:7051
float mx_float
manually define float
Definition: c_api.h:59
Symbol SVMOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float margin=1, mx_float regularization_coefficient=1, bool use_linear=false)
Definition: op.h:6214
Symbol radians(const std::string &symbol_name, Symbol data)
Definition: op.h:3325
Symbol Concat(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int dim=1)
Definition: op.h:4578
L2NormalizationMode
Definition: op.h:6231
Symbol stack(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int axis=0)
Definition: op.h:1633
Symbol floor(const std::string &symbol_name, Symbol data)
Definition: op.h:2914
Symbol broadcast_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1755
Symbol linalg_potrf(const std::string &symbol_name, Symbol A)
Definition: op.h:4074
Symbol take(const std::string &symbol_name, Symbol a, Symbol indices, int axis=0, TakeMode mode=TakeMode::kClip)
Definition: op.h:3602
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:2891
Symbol gammaln(const std::string &symbol_name, Symbol data)
Definition: op.h:3463
Symbol tile(const std::string &symbol_name, Symbol data, Shape reps)
Definition: op.h:1569
Symbol rmspropalex_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, Symbol g, Symbol delta, mx_float lr, mx_float gamma1=0.95, mx_float gamma2=0.9, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:4944
Sample_negative_binomialDtype
Definition: op.h:467
Symbol argsort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2580
SoftmaxNormalization
Definition: op.h:5923
Symbol softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:51
DeconvolutionCudnnTune
Definition: op.h:5698
ConvolutionCudnnTune
Definition: op.h:6616
definition of shape
Symbol broadcast_greater(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3817
Symbol BatchNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, Symbol moving_mean, Symbol moving_var, double eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false, int axis=1, bool cudnn_off=false)
Definition: op.h:5370
Symbol topk(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), int k=1, TopkRetTyp ret_typ=TopkRetTyp::kIndices, bool is_ascend=false)
Definition: op.h:2486
Symbol random_uniform(const std::string &symbol_name, mx_float low=0, mx_float high=1, Shape shape=Shape(), const std::string &ctx="", Random_uniformDtype dtype=Random_uniformDtype::kNone)
Definition: op.h:630
Symbol broadcast_power(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:968
Symbol sample_gamma(const std::string &symbol_name, Symbol alpha, Symbol beta, Shape shape=Shape(), Sample_gammaDtype dtype=Sample_gammaDtype::kNone)
Definition: op.h:324
SoftmaxOutputNormalization
Definition: op.h:5795
Symbol Flatten(const std::string &symbol_name, Symbol data)
Definition: op.h:1201
Symbol BlockGrad(const std::string &symbol_name, Symbol data)
Definition: op.h:2680
LeakyReLUActType
Definition: op.h:4407
Symbol arccos(const std::string &symbol_name, Symbol data)
Definition: op.h:3263
Symbol argmax_channel(const std::string &symbol_name, Symbol data)
Definition: op.h:1984
Symbol batch_take(const std::string &symbol_name, Symbol a, Symbol indices)
Definition: op.h:3648
Symbol LinearRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:7159
Symbol random_negative_binomial(const std::string &symbol_name, int k=1, mx_float p=1, Shape shape=Shape(), const std::string &ctx="", Random_negative_binomialDtype dtype=Random_negative_binomialDtype::kNone)
Definition: op.h:874
Sample_generalized_negative_binomialDtype
Definition: op.h:533
Symbol choose_element_0index(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:7366
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Definition: op.h:1156
Symbol degrees(const std::string &symbol_name, Symbol data)
Definition: op.h:3305
Symbol one_hot(const std::string &symbol_name, Symbol indices, int depth, double on_value=1, double off_value=0, One_hotDtype dtype=One_hotDtype::kFloat32)
Definition: op.h:3711
Symbol SequenceLast(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false)
Definition: op.h:6119
Random_normalDtype
Definition: op.h:652
Symbol negative(const std::string &symbol_name, Symbol data)
Definition: op.h:2757
Symbol GridGenerator(const std::string &symbol_name, Symbol data, GridGeneratorTransformType transform_type, Shape target_shape=Shape(0, 0))
Definition: op.h:6505
Symbol linalg_gemm2(const std::string &symbol_name, Symbol A, Symbol B, bool transpose_a=false, bool transpose_b=false, double alpha=1)
Definition: op.h:4024
Symbol argmax(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1909
Operator interface.
Definition: operator.h:42
Symbol interface.
Definition: symbol.h:71
Symbol sum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2093
MakeLossNormalization
Definition: op.h:7300
Symbol log10(const std::string &symbol_name, Symbol data)
Definition: op.h:3093
Symbol dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1397
Symbol log2(const std::string &symbol_name, Symbol data)
Definition: op.h:3112