8 #ifndef MXNET_CPP_OP_H_ 9 #define MXNET_CPP_OP_H_ 17 #include "dmlc/optional.h" 63 const std::vector<Symbol>& args) {
66 .CreateSymbol(symbol_name);
85 const std::vector<Symbol>& data,
86 const std::string& op_type) {
89 .CreateSymbol(symbol_name);
150 return Operator(
"broadcast_maximum")
183 return Operator(
"broadcast_minimum")
307 bool keep_highest =
false) {
311 .
SetParam(
"target_shape", target_shape)
312 .
SetParam(
"keep_highest", keep_highest)
527 dmlc::optional<int> end) {
618 dmlc::optional<int> axis = dmlc::optional<int>()) {
736 const std::vector<Symbol>& data,
743 .CreateSymbol(symbol_name);
986 const std::vector<Symbol>& args) {
989 .CreateSymbol(symbol_name);
1025 dmlc::optional<int> axis = dmlc::optional<int>(),
1026 bool keepdims =
false) {
1067 dmlc::optional<int> axis = dmlc::optional<int>(),
1068 bool keepdims =
false) {
1153 dmlc::optional<int> axis = dmlc::optional<int>(),
1154 bool keepdims =
false) {
1206 bool transpose_a =
false,
1207 bool transpose_b =
false) {
1209 .
SetParam(
"transpose_a", transpose_a)
1210 .
SetParam(
"transpose_b", transpose_b)
1241 bool transpose_a =
false,
1242 bool transpose_b =
false) {
1244 .
SetParam(
"transpose_a", transpose_a)
1245 .
SetParam(
"transpose_b", transpose_b)
1422 static const char *CastDtypeValues[] = {
1430 .
SetParam(
"dtype", CastDtypeValues[
int(dtype)])
2067 bool keepdims =
false,
2068 bool exclude =
false) {
2104 bool keepdims =
false,
2105 bool exclude =
false) {
2141 bool keepdims =
false,
2142 bool exclude =
false) {
2180 bool keepdims =
false,
2181 bool exclude =
false) {
2219 bool keepdims =
false,
2220 bool exclude =
false) {
2256 bool keepdims =
false,
2257 bool exclude =
false) {
2293 bool keepdims =
false,
2294 bool exclude =
false) {
2467 dmlc::optional<int> axis = dmlc::optional<int>(-1),
2470 bool is_ascend =
false) {
2471 static const char *TopkRetTypValues[] = {
2480 .
SetParam(
"ret_typ", TopkRetTypValues[
int(ret_typ)])
2520 dmlc::optional<int> axis = dmlc::optional<int>(-1),
2521 bool is_ascend =
true) {
2561 dmlc::optional<int> axis = dmlc::optional<int>(-1),
2562 bool is_ascend =
true) {
2732 static const char *EmbeddingDtypeValues[] = {
2741 .
SetParam(
"output_dim", output_dim)
2742 .
SetParam(
"dtype", EmbeddingDtypeValues[
int(dtype)])
2802 static const char *TakeModeValues[] = {
2809 .
SetParam(
"mode", TakeModeValues[
int(mode)])
2909 double on_value = 1,
2910 double off_value = 0,
2912 static const char *One_hotDtypeValues[] = {
2923 .
SetParam(
"dtype", One_hotDtypeValues[
int(dtype)])
3067 return Operator(
"broadcast_not_equal")
3098 return Operator(
"broadcast_greater")
3129 return Operator(
"broadcast_greater_equal")
3160 return Operator(
"broadcast_lesser")
3191 return Operator(
"broadcast_lesser_equal")
3312 static const char *Cast_storageStypeValues[] = {
3318 .
SetParam(
"stype", Cast_storageStypeValues[
int(stype)])
3769 const std::vector<Symbol>& data,
3773 uint32_t num_filter = 0,
3775 uint64_t workspace = 512) {
3776 static const char *UpSamplingSampleTypeValues[] = {
3780 static const char *UpSamplingMultiInputModeValues[] = {
3786 .
SetParam(
"sample_type", UpSamplingSampleTypeValues[
int(sample_type)])
3788 .
SetParam(
"num_filter", num_filter)
3789 .
SetParam(
"multi_input_mode", UpSamplingMultiInputModeValues[
int(multi_input_mode)])
3792 .CreateSymbol(symbol_name);
3866 bool fix_gamma =
true,
3867 bool use_global_stats =
false,
3868 bool output_mean_var =
false,
3870 bool cudnn_off =
false) {
3875 .
SetParam(
"use_global_stats", use_global_stats)
3876 .
SetParam(
"output_mean_var", output_mean_var)
3882 .
SetInput(
"moving_mean", moving_mean)
3883 .
SetInput(
"moving_var", moving_var)
3996 double constant_value = 0) {
3997 static const char *PadModeValues[] = {
4003 .
SetParam(
"mode", PadModeValues[
int(mode)])
4005 .
SetParam(
"constant_value", constant_value)
4052 const std::vector<Symbol>& data,
4059 .CreateSymbol(symbol_name);
4103 static const char *LeakyReLUActTypeValues[] = {
4110 .
SetParam(
"act_type", LeakyReLUActTypeValues[
int(act_type)])
4112 .
SetParam(
"lower_bound", lower_bound)
4113 .
SetParam(
"upper_bound", upper_bound)
4193 bool squeeze_axis =
false) {
4195 .
SetParam(
"num_outputs", num_outputs)
4197 .
SetParam(
"squeeze_axis", squeeze_axis)
4233 uint32_t dim2 = 0) {
4302 bool fix_gamma =
true,
4303 bool use_global_stats =
false,
4304 bool output_mean_var =
false) {
4309 .
SetParam(
"use_global_stats", use_global_stats)
4310 .
SetParam(
"output_mean_var", output_mean_var)
4357 return Operator(
"softmax_cross_entropy")
4391 return Operator(
"LinearRegressionOutput")
4392 .
SetParam(
"grad_scale", grad_scale)
4427 return Operator(
"MAERegressionOutput")
4428 .
SetParam(
"grad_scale", grad_scale)
4463 return Operator(
"LogisticRegressionOutput")
4464 .
SetParam(
"grad_scale", grad_scale)
4484 return Operator(
"IdentityAttachKLSparseReg")
4485 .
SetParam(
"sparseness_target", sparseness_target)
4529 .
SetParam(
"rescale_grad", rescale_grad)
4530 .
SetParam(
"clip_gradient", clip_gradient)
4584 .
SetParam(
"rescale_grad", rescale_grad)
4585 .
SetParam(
"clip_gradient", clip_gradient)
4630 .
SetParam(
"rescale_grad", rescale_grad)
4631 .
SetParam(
"clip_gradient", clip_gradient)
4695 .
SetParam(
"rescale_grad", rescale_grad)
4696 .
SetParam(
"clip_gradient", clip_gradient)
4728 .
SetParam(
"rescale_grad", rescale_grad)
4729 .
SetParam(
"clip_gradient", clip_gradient)
4762 return Operator(
"mp_sgd_mom_update")
4766 .
SetParam(
"rescale_grad", rescale_grad)
4767 .
SetParam(
"clip_gradient", clip_gradient)
4828 .
SetParam(
"rescale_grad", rescale_grad)
4829 .
SetParam(
"clip_gradient", clip_gradient)
4903 .
SetParam(
"rescale_grad", rescale_grad)
4904 .
SetParam(
"clip_gradient", clip_gradient)
4981 .
SetParam(
"rescale_grad", rescale_grad)
4982 .
SetParam(
"clip_gradient", clip_gradient)
4983 .
SetParam(
"clip_weights", clip_weights)
5049 return Operator(
"rmspropalex_update")
5055 .
SetParam(
"rescale_grad", rescale_grad)
5056 .
SetParam(
"clip_gradient", clip_gradient)
5057 .
SetParam(
"clip_weights", clip_weights)
5121 .
SetParam(
"rescale_grad", rescale_grad)
5122 .
SetParam(
"clip_gradient", clip_gradient)
5202 bool global_pool =
false,
5203 bool cudnn_off =
false,
5207 static const char *PoolingPoolTypeValues[] = {
5212 static const char *PoolingPoolingConventionValues[] = {
5218 .
SetParam(
"pool_type", PoolingPoolTypeValues[
int(pool_type)])
5219 .
SetParam(
"global_pool", global_pool)
5221 .
SetParam(
"pooling_convention", PoolingPoolingConventionValues[
int(pooling_convention)])
5279 uint32_t num_filter,
5285 uint32_t num_group = 1,
5286 uint64_t workspace = 512,
5287 bool no_bias =
true,
5289 bool cudnn_off =
false,
5291 static const char *DeconvolutionCudnnTuneValues[] = {
5294 "limited_workspace",
5297 static const char *DeconvolutionLayoutValues[] = {
5307 .
SetParam(
"num_filter", num_filter)
5312 .
SetParam(
"target_shape", target_shape)
5316 .
SetParam(
"cudnn_tune", DeconvolutionCudnnTuneValues[
int(cudnn_tune)])
5318 .
SetParam(
"layout", DeconvolutionLayoutValues[
int(layout)])
5355 static const char *ActivationActTypeValues[] = {
5362 .
SetParam(
"act_type", ActivationActTypeValues[
int(act_type)])
5487 uint32_t num_filter,
5491 uint32_t num_group = 1,
5492 uint64_t workspace = 1024,
5493 bool no_bias =
false,
5495 bool cudnn_off =
false,
5497 static const char *ConvolutionCudnnTuneValues[] = {
5500 "limited_workspace",
5503 static const char *ConvolutionLayoutValues[] = {
5513 .
SetParam(
"num_filter", num_filter)
5520 .
SetParam(
"cudnn_tune", ConvolutionCudnnTuneValues[
int(cudnn_tune)])
5522 .
SetParam(
"layout", ConvolutionLayoutValues[
int(layout)])
5579 static const char *DropoutModeValues[] = {
5585 .
SetParam(
"mode", DropoutModeValues[
int(mode)])
5634 static const char *SoftmaxActivationModeValues[] = {
5638 return Operator(
"SoftmaxActivation")
5639 .
SetParam(
"mode", SoftmaxActivationModeValues[
int(mode)])
5682 bool no_bias =
false,
5683 bool flatten =
true) {
5685 .
SetParam(
"num_hidden", num_hidden)
5779 static const char *GridGeneratorTransformTypeValues[] = {
5784 .
SetParam(
"transform_type", GridGeneratorTransformTypeValues[
int(transform_type)])
5785 .
SetParam(
"target_shape", target_shape)
5860 bool global_pool =
false,
5864 static const char *Pooling_v1PoolTypeValues[] = {
5869 static const char *Pooling_v1PoolingConventionValues[] = {
5875 .
SetParam(
"pool_type", Pooling_v1PoolTypeValues[
int(pool_type)])
5876 .
SetParam(
"global_pool", global_pool)
5877 .
SetParam(
"pooling_convention", Pooling_v1PoolingConventionValues[
int(pooling_convention)])
5913 uint32_t state_size,
5914 uint32_t num_layers,
5916 bool bidirectional =
false,
5918 bool state_outputs =
false) {
5919 static const char *RNNModeValues[] = {
5926 .
SetParam(
"state_size", state_size)
5927 .
SetParam(
"num_layers", num_layers)
5928 .
SetParam(
"mode", RNNModeValues[
int(mode)])
5929 .
SetParam(
"bidirectional", bidirectional)
5931 .
SetParam(
"state_outputs", state_outputs)
5933 .
SetInput(
"parameters", parameters)
5935 .
SetInput(
"state_cell", state_cell)
5999 uint32_t num_filter,
6003 uint32_t num_group = 1,
6004 uint64_t workspace = 1024,
6005 bool no_bias =
false,
6007 bool cudnn_off =
false,
6009 static const char *Convolution_v1CudnnTuneValues[] = {
6012 "limited_workspace",
6015 static const char *Convolution_v1LayoutValues[] = {
6024 .
SetParam(
"num_filter", num_filter)
6031 .
SetParam(
"cudnn_tune", Convolution_v1CudnnTuneValues[
int(cudnn_tune)])
6033 .
SetParam(
"layout", Convolution_v1LayoutValues[
int(layout)])
6061 const std::vector<Symbol>& data,
6065 bool center_crop =
false) {
6070 .
SetParam(
"center_crop", center_crop)
6072 .CreateSymbol(symbol_name);
6154 bool use_sequence_length =
false,
6157 .
SetParam(
"use_sequence_length", use_sequence_length)
6160 .
SetInput(
"sequence_length", sequence_length)
6192 static const char *SpatialTransformerTransformTypeValues[] = {
6195 static const char *SpatialTransformerSamplerTypeValues[] = {
6198 return Operator(
"SpatialTransformer")
6199 .
SetParam(
"transform_type", SpatialTransformerTransformTypeValues[
int(transform_type)])
6200 .
SetParam(
"sampler_type", SpatialTransformerSamplerTypeValues[
int(sampler_type)])
6201 .
SetParam(
"target_shape", target_shape)
6265 bool use_sequence_length =
false,
6268 .
SetParam(
"use_sequence_length", use_sequence_length)
6271 .
SetInput(
"sequence_length", sequence_length)
6383 bool multi_output =
false,
6384 bool use_ignore =
false,
6385 bool preserve_shape =
false,
6387 bool out_grad =
false,
6389 static const char *SoftmaxOutputNormalizationValues[] = {
6395 .
SetParam(
"grad_scale", grad_scale)
6396 .
SetParam(
"ignore_label", ignore_label)
6397 .
SetParam(
"multi_output", multi_output)
6398 .
SetParam(
"use_ignore", use_ignore)
6399 .
SetParam(
"preserve_shape", preserve_shape)
6400 .
SetParam(
"normalization", SoftmaxOutputNormalizationValues[
int(normalization)])
6402 .
SetParam(
"smooth_alpha", smooth_alpha)
6447 bool multi_output =
false,
6448 bool use_ignore =
false,
6449 bool preserve_shape =
false,
6451 bool out_grad =
false,
6453 static const char *SoftmaxNormalizationValues[] = {
6459 .
SetParam(
"grad_scale", grad_scale)
6460 .
SetParam(
"ignore_label", ignore_label)
6461 .
SetParam(
"multi_output", multi_output)
6462 .
SetParam(
"use_ignore", use_ignore)
6463 .
SetParam(
"preserve_shape", preserve_shape)
6464 .
SetParam(
"normalization", SoftmaxNormalizationValues[
int(normalization)])
6466 .
SetParam(
"smooth_alpha", smooth_alpha)
6622 .
SetParam(
"pooled_size", pooled_size)
6623 .
SetParam(
"spatial_scale", spatial_scale)
6703 static const char *L2NormalizationModeValues[] = {
6710 .
SetParam(
"mode", L2NormalizationModeValues[
int(mode)])
6764 static const char *MakeLossNormalizationValues[] = {
6770 .
SetParam(
"grad_scale", grad_scale)
6771 .
SetParam(
"valid_thresh", valid_thresh)
6772 .
SetParam(
"normalization", MakeLossNormalizationValues[
int(normalization)])
6796 mx_float regularization_coefficient = 1,
6797 bool use_linear =
false) {
6800 .
SetParam(
"regularization_coefficient", regularization_coefficient)
6801 .
SetParam(
"use_linear", use_linear)
6900 uint32_t kernel_size = 1,
6901 uint32_t max_displacement = 1,
6902 uint32_t stride1 = 1,
6903 uint32_t stride2 = 1,
6904 uint32_t pad_size = 0,
6905 bool is_multiply =
true) {
6907 .
SetParam(
"kernel_size", kernel_size)
6908 .
SetParam(
"max_displacement", max_displacement)
6912 .
SetParam(
"is_multiply", is_multiply)
6999 bool use_sequence_length =
false,
7003 .
SetParam(
"use_sequence_length", use_sequence_length)
7007 .
SetInput(
"sequence_length", sequence_length)
7022 return Operator(
"choose_element_0index")
7041 return Operator(
"fill_element_0index")
7108 const std::string& op_type) {
7168 return Operator(
"broadcast_maximum")
7199 return Operator(
"broadcast_minimum")
7319 bool keep_highest =
false) {
7323 .
SetParam(
"target_shape", target_shape)
7324 .
SetParam(
"keep_highest", keep_highest)
7529 dmlc::optional<int> end) {
7616 dmlc::optional<int> axis = dmlc::optional<int>()) {
7999 dmlc::optional<int> axis = dmlc::optional<int>(),
8000 bool keepdims =
false) {
8039 dmlc::optional<int> axis = dmlc::optional<int>(),
8040 bool keepdims =
false) {
8121 dmlc::optional<int> axis = dmlc::optional<int>(),
8122 bool keepdims =
false) {
8172 bool transpose_a =
false,
8173 bool transpose_b =
false) {
8175 .
SetParam(
"transpose_a", transpose_a)
8176 .
SetParam(
"transpose_b", transpose_b)
8205 bool transpose_a =
false,
8206 bool transpose_b =
false) {
8208 .
SetParam(
"transpose_a", transpose_a)
8209 .
SetParam(
"transpose_b", transpose_b)
8364 static const char *CastDtypeValues[] = {
8372 .
SetParam(
"dtype", CastDtypeValues[
int(dtype)])
8961 bool keepdims =
false,
8962 bool exclude =
false) {
8996 bool keepdims =
false,
8997 bool exclude =
false) {
9031 bool keepdims =
false,
9032 bool exclude =
false) {
9068 bool keepdims =
false,
9069 bool exclude =
false) {
9105 bool keepdims =
false,
9106 bool exclude =
false) {
9140 bool keepdims =
false,
9141 bool exclude =
false) {
9175 bool keepdims =
false,
9176 bool exclude =
false) {
9329 dmlc::optional<int> axis = dmlc::optional<int>(-1),
9332 bool is_ascend =
false) {
9333 static const char *TopkRetTypValues[] = {
9342 .
SetParam(
"ret_typ", TopkRetTypValues[
int(ret_typ)])
9380 dmlc::optional<int> axis = dmlc::optional<int>(-1),
9381 bool is_ascend =
true) {
9419 dmlc::optional<int> axis = dmlc::optional<int>(-1),
9420 bool is_ascend =
true) {
9570 static const char *EmbeddingDtypeValues[] = {
9579 .
SetParam(
"output_dim", output_dim)
9580 .
SetParam(
"dtype", EmbeddingDtypeValues[
int(dtype)])
9628 static const char *TakeModeValues[] = {
9635 .
SetParam(
"mode", TakeModeValues[
int(mode)])
9721 double on_value = 1,
9722 double off_value = 0,
9724 static const char *One_hotDtypeValues[] = {
9735 .
SetParam(
"dtype", One_hotDtypeValues[
int(dtype)])
9871 return Operator(
"broadcast_not_equal")
9900 return Operator(
"broadcast_greater")
9929 return Operator(
"broadcast_greater_equal")
9958 return Operator(
"broadcast_lesser")
9987 return Operator(
"broadcast_lesser_equal")
10094 static const char *Cast_storageStypeValues[] = {
10100 .
SetParam(
"stype", Cast_storageStypeValues[
int(stype)])
10506 uint32_t num_filter = 0,
10508 uint64_t workspace = 512) {
10509 static const char *UpSamplingSampleTypeValues[] = {
10513 static const char *UpSamplingMultiInputModeValues[] = {
10519 .
SetParam(
"sample_type", UpSamplingSampleTypeValues[
int(sample_type)])
10521 .
SetParam(
"num_filter", num_filter)
10522 .
SetParam(
"multi_input_mode", UpSamplingMultiInputModeValues[
int(multi_input_mode)])
10595 double eps = 0.001,
10597 bool fix_gamma =
true,
10598 bool use_global_stats =
false,
10599 bool output_mean_var =
false,
10601 bool cudnn_off =
false) {
10606 .
SetParam(
"use_global_stats", use_global_stats)
10607 .
SetParam(
"output_mean_var", output_mean_var)
10613 .
SetInput(
"moving_mean", moving_mean)
10614 .
SetInput(
"moving_var", moving_var)
10716 double constant_value = 0) {
10717 static const char *PadModeValues[] = {
10723 .
SetParam(
"mode", PadModeValues[
int(mode)])
10725 .
SetParam(
"constant_value", constant_value)
10810 static const char *LeakyReLUActTypeValues[] = {
10817 .
SetParam(
"act_type", LeakyReLUActTypeValues[
int(act_type)])
10819 .
SetParam(
"lower_bound", lower_bound)
10820 .
SetParam(
"upper_bound", upper_bound)
10898 bool squeeze_axis =
false) {
10900 .
SetParam(
"num_outputs", num_outputs)
10902 .
SetParam(
"squeeze_axis", squeeze_axis)
10936 uint32_t dim2 = 0) {
11003 bool fix_gamma =
true,
11004 bool use_global_stats =
false,
11005 bool output_mean_var =
false) {
11010 .
SetParam(
"use_global_stats", use_global_stats)
11011 .
SetParam(
"output_mean_var", output_mean_var)
11056 return Operator(
"softmax_cross_entropy")
11088 return Operator(
"LinearRegressionOutput")
11089 .
SetParam(
"grad_scale", grad_scale)
11122 return Operator(
"MAERegressionOutput")
11123 .
SetParam(
"grad_scale", grad_scale)
11156 return Operator(
"LogisticRegressionOutput")
11157 .
SetParam(
"grad_scale", grad_scale)
11175 return Operator(
"IdentityAttachKLSparseReg")
11176 .
SetParam(
"sparseness_target", sparseness_target)
11218 .
SetParam(
"rescale_grad", rescale_grad)
11219 .
SetParam(
"clip_gradient", clip_gradient)
11271 .
SetParam(
"rescale_grad", rescale_grad)
11272 .
SetParam(
"clip_gradient", clip_gradient)
11315 .
SetParam(
"rescale_grad", rescale_grad)
11316 .
SetParam(
"clip_gradient", clip_gradient)
11378 .
SetParam(
"rescale_grad", rescale_grad)
11379 .
SetParam(
"clip_gradient", clip_gradient)
11409 .
SetParam(
"rescale_grad", rescale_grad)
11410 .
SetParam(
"clip_gradient", clip_gradient)
11441 return Operator(
"mp_sgd_mom_update")
11445 .
SetParam(
"rescale_grad", rescale_grad)
11446 .
SetParam(
"clip_gradient", clip_gradient)
11505 .
SetParam(
"rescale_grad", rescale_grad)
11506 .
SetParam(
"clip_gradient", clip_gradient)
11578 .
SetParam(
"rescale_grad", rescale_grad)
11579 .
SetParam(
"clip_gradient", clip_gradient)
11654 .
SetParam(
"rescale_grad", rescale_grad)
11655 .
SetParam(
"clip_gradient", clip_gradient)
11656 .
SetParam(
"clip_weights", clip_weights)
11720 return Operator(
"rmspropalex_update")
11726 .
SetParam(
"rescale_grad", rescale_grad)
11727 .
SetParam(
"clip_gradient", clip_gradient)
11728 .
SetParam(
"clip_weights", clip_weights)
11790 .
SetParam(
"rescale_grad", rescale_grad)
11791 .
SetParam(
"clip_gradient", clip_gradient)
11854 bool global_pool =
false,
11855 bool cudnn_off =
false,
11859 static const char *PoolingPoolTypeValues[] = {
11864 static const char *PoolingPoolingConventionValues[] = {
11870 .
SetParam(
"pool_type", PoolingPoolTypeValues[
int(pool_type)])
11871 .
SetParam(
"global_pool", global_pool)
11873 .
SetParam(
"pooling_convention", PoolingPoolingConventionValues[
int(pooling_convention)])
11909 uint32_t num_filter,
11915 uint32_t num_group = 1,
11916 uint64_t workspace = 512,
11917 bool no_bias =
true,
11919 bool cudnn_off =
false,
11921 static const char *DeconvolutionCudnnTuneValues[] = {
11924 "limited_workspace",
11927 static const char *DeconvolutionLayoutValues[] = {
11937 .
SetParam(
"num_filter", num_filter)
11942 .
SetParam(
"target_shape", target_shape)
11946 .
SetParam(
"cudnn_tune", DeconvolutionCudnnTuneValues[
int(cudnn_tune)])
11948 .
SetParam(
"layout", DeconvolutionLayoutValues[
int(layout)])
11974 static const char *ActivationActTypeValues[] = {
11981 .
SetParam(
"act_type", ActivationActTypeValues[
int(act_type)])
12083 uint32_t num_filter,
12087 uint32_t num_group = 1,
12088 uint64_t workspace = 1024,
12089 bool no_bias =
false,
12091 bool cudnn_off =
false,
12093 static const char *ConvolutionCudnnTuneValues[] = {
12096 "limited_workspace",
12099 static const char *ConvolutionLayoutValues[] = {
12109 .
SetParam(
"num_filter", num_filter)
12116 .
SetParam(
"cudnn_tune", ConvolutionCudnnTuneValues[
int(cudnn_tune)])
12118 .
SetParam(
"layout", ConvolutionLayoutValues[
int(layout)])
12166 static const char *DropoutModeValues[] = {
12172 .
SetParam(
"mode", DropoutModeValues[
int(mode)])
12211 static const char *SoftmaxActivationModeValues[] = {
12215 return Operator(
"SoftmaxActivation")
12216 .
SetParam(
"mode", SoftmaxActivationModeValues[
int(mode)])
12257 bool no_bias =
false,
12258 bool flatten =
true) {
12260 .
SetParam(
"num_hidden", num_hidden)
12342 static const char *GridGeneratorTransformTypeValues[] = {
12347 .
SetParam(
"transform_type", GridGeneratorTransformTypeValues[
int(transform_type)])
12348 .
SetParam(
"target_shape", target_shape)
12406 bool global_pool =
false,
12410 static const char *Pooling_v1PoolTypeValues[] = {
12415 static const char *Pooling_v1PoolingConventionValues[] = {
12421 .
SetParam(
"pool_type", Pooling_v1PoolTypeValues[
int(pool_type)])
12422 .
SetParam(
"global_pool", global_pool)
12423 .
SetParam(
"pooling_convention", Pooling_v1PoolingConventionValues[
int(pooling_convention)])
12448 uint32_t state_size,
12449 uint32_t num_layers,
12451 bool bidirectional =
false,
12453 bool state_outputs =
false) {
12454 static const char *RNNModeValues[] = {
12461 .
SetParam(
"state_size", state_size)
12462 .
SetParam(
"num_layers", num_layers)
12463 .
SetParam(
"mode", RNNModeValues[
int(mode)])
12464 .
SetParam(
"bidirectional", bidirectional)
12466 .
SetParam(
"state_outputs", state_outputs)
12468 .
SetInput(
"parameters", parameters)
12470 .
SetInput(
"state_cell", state_cell)
12505 uint32_t num_filter,
12509 uint32_t num_group = 1,
12510 uint64_t workspace = 1024,
12511 bool no_bias =
false,
12513 bool cudnn_off =
false,
12515 static const char *Convolution_v1CudnnTuneValues[] = {
12518 "limited_workspace",
12521 static const char *Convolution_v1LayoutValues[] = {
12530 .
SetParam(
"num_filter", num_filter)
12537 .
SetParam(
"cudnn_tune", Convolution_v1CudnnTuneValues[
int(cudnn_tune)])
12539 .
SetParam(
"layout", Convolution_v1LayoutValues[
int(layout)])
12569 bool center_crop =
false) {
12574 .
SetParam(
"center_crop", center_crop)
12656 bool use_sequence_length =
false,
12658 return Operator(
"SequenceReverse")
12659 .
SetParam(
"use_sequence_length", use_sequence_length)
12662 .
SetInput(
"sequence_length", sequence_length)
12680 static const char *SpatialTransformerTransformTypeValues[] = {
12683 static const char *SpatialTransformerSamplerTypeValues[] = {
12686 return Operator(
"SpatialTransformer")
12687 .
SetParam(
"transform_type", SpatialTransformerTransformTypeValues[
int(transform_type)])
12688 .
SetParam(
"sampler_type", SpatialTransformerSamplerTypeValues[
int(sampler_type)])
12689 .
SetParam(
"target_shape", target_shape)
12751 bool use_sequence_length =
false,
12754 .
SetParam(
"use_sequence_length", use_sequence_length)
12757 .
SetInput(
"sequence_length", sequence_length)
12859 bool multi_output =
false,
12860 bool use_ignore =
false,
12861 bool preserve_shape =
false,
12863 bool out_grad =
false,
12865 static const char *SoftmaxOutputNormalizationValues[] = {
12871 .
SetParam(
"grad_scale", grad_scale)
12872 .
SetParam(
"ignore_label", ignore_label)
12873 .
SetParam(
"multi_output", multi_output)
12874 .
SetParam(
"use_ignore", use_ignore)
12875 .
SetParam(
"preserve_shape", preserve_shape)
12876 .
SetParam(
"normalization", SoftmaxOutputNormalizationValues[
int(normalization)])
12878 .
SetParam(
"smooth_alpha", smooth_alpha)
12913 bool multi_output =
false,
12914 bool use_ignore =
false,
12915 bool preserve_shape =
false,
12917 bool out_grad =
false,
12919 static const char *SoftmaxNormalizationValues[] = {
12925 .
SetParam(
"grad_scale", grad_scale)
12926 .
SetParam(
"ignore_label", ignore_label)
12927 .
SetParam(
"multi_output", multi_output)
12928 .
SetParam(
"use_ignore", use_ignore)
12929 .
SetParam(
"preserve_shape", preserve_shape)
12930 .
SetParam(
"normalization", SoftmaxNormalizationValues[
int(normalization)])
12932 .
SetParam(
"smooth_alpha", smooth_alpha)
13018 return Operator(
"BilinearSampler")
13084 .
SetParam(
"pooled_size", pooled_size)
13085 .
SetParam(
"spatial_scale", spatial_scale)
13155 static const char *L2NormalizationModeValues[] = {
13160 return Operator(
"L2Normalization")
13162 .
SetParam(
"mode", L2NormalizationModeValues[
int(mode)])
13204 static const char *MakeLossNormalizationValues[] = {
13210 .
SetParam(
"grad_scale", grad_scale)
13211 .
SetParam(
"valid_thresh", valid_thresh)
13212 .
SetParam(
"normalization", MakeLossNormalizationValues[
int(normalization)])
13234 mx_float regularization_coefficient = 1,
13235 bool use_linear =
false) {
13238 .
SetParam(
"regularization_coefficient", regularization_coefficient)
13239 .
SetParam(
"use_linear", use_linear)
13334 uint32_t kernel_size = 1,
13335 uint32_t max_displacement = 1,
13336 uint32_t stride1 = 1,
13337 uint32_t stride2 = 1,
13338 uint32_t pad_size = 0,
13339 bool is_multiply =
true) {
13341 .
SetParam(
"kernel_size", kernel_size)
13342 .
SetParam(
"max_displacement", max_displacement)
13346 .
SetParam(
"is_multiply", is_multiply)
13431 bool use_sequence_length =
false,
13435 .
SetParam(
"use_sequence_length", use_sequence_length)
13439 .
SetInput(
"sequence_length", sequence_length)
13452 return Operator(
"choose_element_0index")
13469 return Operator(
"fill_element_0index")
13478 #endif // MXNET_CPP_OP_H_ Symbol Convolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, ConvolutionCudnnTune cudnn_tune=ConvolutionCudnnTune::kNone, bool cudnn_off=false, ConvolutionLayout layout=ConvolutionLayout::kNone)
Definition: op.h:5482
Symbol fix(const std::string &symbol_name, Symbol data)
Definition: op.h:1692
Symbol Crop(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, Shape offset=Shape(0, 0), Shape h_w=Shape(0, 0), bool center_crop=false)
Definition: op.h:6060
Symbol min(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2290
Symbol broadcast_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:894
Symbol arcsin(const std::string &symbol_name, Symbol data)
Definition: op.h:3422
Symbol FullyConnected(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, int num_hidden, bool no_bias=false, bool flatten=true)
Definition: op.h:5677
Symbol arccosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3639
Symbol arctan(const std::string &symbol_name, Symbol data)
Definition: op.h:3474
Symbol SwapAxis(const std::string &symbol_name, Symbol data, uint32_t dim1=0, uint32_t dim2=0)
Definition: op.h:4230
Symbol cast_storage(const std::string &symbol_name, Symbol data, Cast_storageStype stype)
Definition: op.h:3309
Symbol nansum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2177
Symbol add_n(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:985
Symbol log1p(const std::string &symbol_name, Symbol data)
Definition: op.h:1939
SoftmaxActivationMode
Definition: op.h:5593
Symbol mp_sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol weight32, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4717
Symbol SpatialTransformer(const std::string &symbol_name, Symbol data, Symbol loc, SpatialTransformerTransformType transform_type, SpatialTransformerSamplerType sampler_type, Shape target_shape=Shape(0, 0))
Definition: op.h:6186
Symbol slice(const std::string &symbol_name, Symbol data, Shape begin, Shape end, Shape step=Shape())
Definition: op.h:478
Symbol exp(const std::string &symbol_name, Symbol data)
Definition: op.h:1851
Symbol transpose(const std::string &symbol_name, Symbol data, Shape axes=Shape())
Definition: op.h:390
Symbol clip(const std::string &symbol_name, Symbol data, mx_float a_min, mx_float a_max)
Definition: op.h:570
Symbol elemwise_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2656
Symbol ROIPooling(const std::string &symbol_name, Symbol data, Symbol rois, Shape pooled_size, mx_float spatial_scale)
Definition: op.h:6616
Symbol broadcast_div(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:925
Symbol nanprod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2216
Convolution_v1Layout
Definition: op.h:5958
Symbol argmin(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1065
Symbol SequenceReverse(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:6151
Symbol mp_sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, Symbol weight32, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4752
Symbol broadcast_lesser(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3157
Symbol fill_element_0index(const std::string &symbol_name, Symbol lhs, Symbol mhs, Symbol rhs)
Definition: op.h:7037
Symbol Convolution_v1(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), uint32_t num_group=1, uint64_t workspace=1024, bool no_bias=false, Convolution_v1CudnnTune cudnn_tune=Convolution_v1CudnnTune::kNone, bool cudnn_off=false, Convolution_v1Layout layout=Convolution_v1Layout::kNone)
Definition: op.h:5994
Symbol broadcast_not_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3064
TakeMode
Definition: op.h:2752
Symbol Embedding(const std::string &symbol_name, Symbol data, Symbol weight, int input_dim, int output_dim, EmbeddingDtype dtype=EmbeddingDtype::kFloat32)
Definition: op.h:2726
Symbol SequenceLast(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, int axis=0)
Definition: op.h:6262
Symbol ftrl_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol z, Symbol n, mx_float lr, mx_float lamda1=0.01, mx_float beta=1, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:5105
Symbol reciprocal(const std::string &symbol_name, Symbol data)
Definition: op.h:1472
TopkRetTyp
Definition: op.h:2417
Symbol RNN(const std::string &symbol_name, Symbol data, Symbol parameters, Symbol state, Symbol state_cell, uint32_t state_size, uint32_t num_layers, RNNMode mode, bool bidirectional=false, mx_float p=0, bool state_outputs=false)
Definition: op.h:5908
namespace of mxnet
Definition: base.h:127
Symbol reshape_like(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:1381
Pooling_v1PoolingConvention
Definition: op.h:5800
Symbol broadcast_lesser_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3188
Operator & SetInput(const std::string &name, Symbol symbol)
add an input symbol
Symbol InstanceNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001)
Definition: op.h:5744
Symbol sign(const std::string &symbol_name, Symbol data)
Definition: op.h:1524
GridGeneratorTransformType
Definition: op.h:5760
Cast_storageStype
Definition: op.h:3260
dynamic shape class that can hold a shape of arbitrary dimension
Definition: shape.h:43
Symbol ones_like(const std::string &symbol_name, Symbol data)
Definition: op.h:793
RNNMode
Definition: op.h:5886
PadMode
Definition: op.h:3890
Symbol smooth_l1(const std::string &symbol_name, Symbol data, mx_float scalar)
Definition: op.h:3249
Symbol where(const std::string &symbol_name, Symbol condition, Symbol x, Symbol y)
Definition: op.h:3213
Symbol Dropout(const std::string &symbol_name, Symbol data, mx_float p=0.5, DropoutMode mode=DropoutMode::kTraining)
Definition: op.h:5575
Symbol expm1(const std::string &symbol_name, Symbol data)
Definition: op.h:1963
Symbol elemwise_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2585
PoolingPoolType
Definition: op.h:5132
Symbol relu(const std::string &symbol_name, Symbol data)
Definition: op.h:1269
Symbol reverse(const std::string &symbol_name, Symbol data, Shape axis)
Definition: op.h:703
Symbol rsqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1777
Symbol Pooling(const std::string &symbol_name, Symbol data, Shape kernel, PoolingPoolType pool_type, bool global_pool=false, bool cudnn_off=false, PoolingPoolingConvention pooling_convention=PoolingPoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:5198
Symbol Pooling_v1(const std::string &symbol_name, Symbol data, Shape kernel, Pooling_v1PoolType pool_type, bool global_pool=false, Pooling_v1PoolingConvention pooling_convention=Pooling_v1PoolingConvention::kValid, Shape stride=Shape(), Shape pad=Shape())
Definition: op.h:5856
SpatialTransformerTransformType
Definition: op.h:6166
ActivationActType
Definition: op.h:5327
Symbol sqrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1751
Symbol Softmax(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxNormalization normalization=SoftmaxNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:6443
Symbol rint(const std::string &symbol_name, Symbol data)
Definition: op.h:1580
Symbol IdentityAttachKLSparseReg(const std::string &symbol_name, Symbol data, mx_float sparseness_target=0.1, mx_float penalty=0.001, mx_float momentum=0.9)
Definition: op.h:4479
Symbol sinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3549
Symbol scatter_nd(const std::string &symbol_name, Symbol data, Symbol indices, Shape shape)
Definition: op.h:3000
Symbol broadcast_greater_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3126
Symbol LRN(const std::string &symbol_name, Symbol data, uint32_t nsize, mx_float alpha=0.0001, mx_float beta=0.75, mx_float knorm=2)
Definition: op.h:6834
Symbol sgd_mom_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4682
Symbol max(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2253
Symbol arcsinh(const std::string &symbol_name, Symbol data)
Definition: op.h:3619
Symbol sgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4620
Symbol MAERegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:4423
Symbol SliceChannel(const std::string &symbol_name, Symbol data, int num_outputs, int axis=1, bool squeeze_axis=false)
Definition: op.h:4189
PoolingPoolingConvention
Definition: op.h:5140
Symbol broadcast_minimum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:180
Symbol broadcast_maximum(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:147
Symbol Cast(const std::string &symbol_name, Symbol data, CastDtype dtype)
Definition: op.h:1419
DeconvolutionLayout
Definition: op.h:5239
Symbol trunc(const std::string &symbol_name, Symbol data)
Definition: op.h:1665
Pooling_v1PoolType
Definition: op.h:5792
Symbol round(const std::string &symbol_name, Symbol data)
Definition: op.h:1550
Symbol log_softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:3729
Symbol khatri_rao(const std::string &symbol_name, const std::vector< Symbol > &args)
Definition: op.h:62
Symbol cos(const std::string &symbol_name, Symbol data)
Definition: op.h:3367
Symbol L2Normalization(const std::string &symbol_name, Symbol data, mx_float eps=1e-10, L2NormalizationMode mode=L2NormalizationMode::kInstance)
Definition: op.h:6699
Symbol Correlation(const std::string &symbol_name, Symbol data1, Symbol data2, uint32_t kernel_size=1, uint32_t max_displacement=1, uint32_t stride1=1, uint32_t stride2=1, uint32_t pad_size=0, bool is_multiply=true)
Definition: op.h:6897
Symbol zeros_like(const std::string &symbol_name, Symbol data)
Definition: op.h:769
EmbeddingDtype
Definition: op.h:2667
Symbol batch_dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1238
Symbol broadcast_mod(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:956
Symbol cbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1801
Symbol prod(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2138
operator helper functions
Symbol mean(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2101
Symbol tanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3596
Symbol broadcast_to(const std::string &symbol_name, Symbol data, Shape shape=Shape())
Definition: op.h:2371
Symbol elemwise_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2609
DropoutMode
Definition: op.h:5531
Symbol MakeLoss(const std::string &symbol_name, Symbol data, mx_float grad_scale=1, mx_float valid_thresh=0, MakeLossNormalization normalization=MakeLossNormalization::kNull)
Definition: op.h:6759
Symbol log(const std::string &symbol_name, Symbol data)
Definition: op.h:1872
Symbol sigmoid(const std::string &symbol_name, Symbol data)
Definition: op.h:1291
CastDtype
Definition: op.h:1392
ConvolutionLayout
Definition: op.h:5379
Symbol LogisticRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:4459
Symbol gamma(const std::string &symbol_name, Symbol data)
Definition: op.h:1981
Symbol sin(const std::string &symbol_name, Symbol data)
Definition: op.h:3343
UpSamplingMultiInputMode
Definition: op.h:3748
Symbol CreateSymbol(const std::string &name="")
create a Symbol from the current operator
Symbol elemwise_mul(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:2636
SpatialTransformerSamplerType
Definition: op.h:6172
Symbol Pad(const std::string &symbol_name, Symbol data, PadMode mode, Shape pad_width, double constant_value=0)
Definition: op.h:3992
Symbol square(const std::string &symbol_name, Symbol data)
Definition: op.h:1722
One_hotDtype
Definition: op.h:2854
UpSamplingSampleType
Definition: op.h:3740
Symbol norm(const std::string &symbol_name, Symbol data)
Definition: op.h:2405
Symbol rmsprop_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, mx_float lr, mx_float gamma1=0.95, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:4965
Symbol LeakyReLU(const std::string &symbol_name, Symbol data, LeakyReLUActType act_type=LeakyReLUActType::kLeaky, mx_float slope=0.25, mx_float lower_bound=0.125, mx_float upper_bound=0.334)
Definition: op.h:4097
Symbol make_loss(const std::string &symbol_name, Symbol data)
Definition: op.h:1367
Symbol SoftmaxActivation(const std::string &symbol_name, Symbol data, SoftmaxActivationMode mode=SoftmaxActivationMode::kInstance)
Definition: op.h:5631
Symbol broadcast_equal(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3033
Symbol Deconvolution(const std::string &symbol_name, Symbol data, Symbol weight, Symbol bias, Shape kernel, uint32_t num_filter, Shape stride=Shape(), Shape dilate=Shape(), Shape pad=Shape(), Shape adj=Shape(), Shape target_shape=Shape(), uint32_t num_group=1, uint64_t workspace=512, bool no_bias=true, DeconvolutionCudnnTune cudnn_tune=DeconvolutionCudnnTune::kNone, bool cudnn_off=false, DeconvolutionLayout layout=DeconvolutionLayout::kNone)
Definition: op.h:5274
Symbol broadcast_add(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:827
Symbol adam_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mean, Symbol var, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4885
Operator & SetParam(const std::string &name, const T &value)
set config parameters
Definition: operator.h:58
Symbol tan(const std::string &symbol_name, Symbol data)
Definition: op.h:3394
Convolution_v1CudnnTune
Definition: op.h:5948
Symbol repeat(const std::string &symbol_name, Symbol data, int repeats, dmlc::optional< int > axis=dmlc::optional< int >())
Definition: op.h:615
Symbol slice_axis(const std::string &symbol_name, Symbol data, int axis, int begin, dmlc::optional< int > end)
Definition: op.h:523
Symbol expand_dims(const std::string &symbol_name, Symbol data, int axis)
Definition: op.h:414
Symbol arctanh(const std::string &symbol_name, Symbol data)
Definition: op.h:3662
Symbol softmax_cross_entropy(const std::string &symbol_name, Symbol data, Symbol label)
Definition: op.h:4354
Symbol pick(const std::string &symbol_name, Symbol data, Symbol index, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1150
Symbol broadcast_axis(const std::string &symbol_name, Symbol data, Shape axis=Shape(), Shape size=Shape())
Definition: op.h:2332
Symbol abs(const std::string &symbol_name, Symbol data)
Definition: op.h:1498
Symbol cosh(const std::string &symbol_name, Symbol data)
Definition: op.h:3571
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2518
Symbol gather_nd(const std::string &symbol_name, Symbol data, Symbol indices)
Definition: op.h:2955
Symbol BilinearSampler(const std::string &symbol_name, Symbol data, Symbol grid)
Definition: op.h:6551
Symbol Custom(const std::string &symbol_name, const std::vector< Symbol > &data, const std::string &op_type)
Definition: op.h:84
Symbol broadcast_hypot(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:219
Symbol BatchNorm_v1(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, mx_float eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false)
Definition: op.h:4296
Symbol UpSampling(const std::string &symbol_name, const std::vector< Symbol > &data, uint32_t scale, UpSamplingSampleType sample_type, int num_args, uint32_t num_filter=0, UpSamplingMultiInputMode multi_input_mode=UpSamplingMultiInputMode::kConcat, uint64_t workspace=512)
Definition: op.h:3768
Symbol Activation(const std::string &symbol_name, Symbol data, ActivationActType act_type)
Definition: op.h:5352
float mx_float
manually define float
Definition: c_api.h:60
Symbol SVMOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float margin=1, mx_float regularization_coefficient=1, bool use_linear=false)
Definition: op.h:6792
Symbol radians(const std::string &symbol_name, Symbol data)
Definition: op.h:3524
Symbol Concat(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int dim=1)
Definition: op.h:4051
Symbol ftml_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol d, Symbol v, Symbol z, mx_float lr, mx_float beta1=0.9, mx_float beta2=0.999, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4809
L2NormalizationMode
Definition: op.h:6631
Symbol SequenceMask(const std::string &symbol_name, Symbol data, Symbol sequence_length, bool use_sequence_length=false, mx_float value=0, int axis=0)
Definition: op.h:6996
Symbol stack(const std::string &symbol_name, const std::vector< Symbol > &data, int num_args, int axis=0)
Definition: op.h:735
Symbol floor(const std::string &symbol_name, Symbol data)
Definition: op.h:1636
Symbol broadcast_sub(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:863
Symbol take(const std::string &symbol_name, Symbol a, Symbol indices, int axis=0, TakeMode mode=TakeMode::kClip)
Definition: op.h:2797
Symbol ceil(const std::string &symbol_name, Symbol data)
Definition: op.h:1608
Symbol gammaln(const std::string &symbol_name, Symbol data)
Definition: op.h:1999
Symbol tile(const std::string &symbol_name, Symbol data, Shape reps)
Definition: op.h:671
Symbol signum_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol mom, mx_float lr, mx_float momentum=0, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float wd_lh=0)
Definition: op.h:4570
Symbol rmspropalex_update(const std::string &symbol_name, Symbol weight, Symbol grad, Symbol n, Symbol g, Symbol delta, mx_float lr, mx_float gamma1=0.95, mx_float gamma2=0.9, mx_float epsilon=1e-08, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1, mx_float clip_weights=-1)
Definition: op.h:5035
Symbol argsort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=true)
Definition: op.h:2559
SoftmaxNormalization
Definition: op.h:6410
Symbol softmax(const std::string &symbol_name, Symbol data, int axis=-1)
Definition: op.h:3698
DeconvolutionCudnnTune
Definition: op.h:5230
ConvolutionCudnnTune
Definition: op.h:5369
Symbol broadcast_greater(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:3095
Symbol BatchNorm(const std::string &symbol_name, Symbol data, Symbol gamma, Symbol beta, Symbol moving_mean, Symbol moving_var, double eps=0.001, mx_float momentum=0.9, bool fix_gamma=true, bool use_global_stats=false, bool output_mean_var=false, int axis=1, bool cudnn_off=false)
Definition: op.h:3858
Symbol rcbrt(const std::string &symbol_name, Symbol data)
Definition: op.h:1825
Symbol topk(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), int k=1, TopkRetTyp ret_typ=TopkRetTyp::kIndices, bool is_ascend=false)
Definition: op.h:2465
Symbol signsgd_update(const std::string &symbol_name, Symbol weight, Symbol grad, mx_float lr, mx_float wd=0, mx_float rescale_grad=1, mx_float clip_gradient=-1)
Definition: op.h:4519
Symbol broadcast_power(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:114
SoftmaxOutputNormalization
Definition: op.h:6277
Symbol SoftmaxOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1, mx_float ignore_label=-1, bool multi_output=false, bool use_ignore=false, bool preserve_shape=false, SoftmaxOutputNormalization normalization=SoftmaxOutputNormalization::kNull, bool out_grad=false, mx_float smooth_alpha=0)
Definition: op.h:6378
Symbol Flatten(const std::string &symbol_name, Symbol data)
Definition: op.h:347
Symbol BlockGrad(const std::string &symbol_name, Symbol data)
Definition: op.h:1331
LeakyReLUActType
Definition: op.h:4064
Symbol arccos(const std::string &symbol_name, Symbol data)
Definition: op.h:3447
Symbol argmax_channel(const std::string &symbol_name, Symbol data)
Definition: op.h:1098
Symbol batch_take(const std::string &symbol_name, Symbol a, Symbol indices)
Definition: op.h:2843
Symbol LinearRegressionOutput(const std::string &symbol_name, Symbol data, Symbol label, mx_float grad_scale=1)
Definition: op.h:4387
Symbol choose_element_0index(const std::string &symbol_name, Symbol lhs, Symbol rhs)
Definition: op.h:7019
Symbol Reshape(const std::string &symbol_name, Symbol data, Shape shape=Shape(), bool reverse=false, Shape target_shape=Shape(), bool keep_highest=false)
Definition: op.h:302
Symbol degrees(const std::string &symbol_name, Symbol data)
Definition: op.h:3499
Symbol one_hot(const std::string &symbol_name, Symbol indices, int depth, double on_value=1, double off_value=0, One_hotDtype dtype=One_hotDtype::kFloat32)
Definition: op.h:2906
Symbol negative(const std::string &symbol_name, Symbol data)
Definition: op.h:1449
Symbol GridGenerator(const std::string &symbol_name, Symbol data, GridGeneratorTransformType transform_type, Shape target_shape=Shape(0, 0))
Definition: op.h:5775
Symbol argmax(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(), bool keepdims=false)
Definition: op.h:1023
Operator interface.
Definition: operator.h:43
Symbol interface.
Definition: symbol.h:72
Symbol sum(const std::string &symbol_name, Symbol data, Shape axis=Shape(), bool keepdims=false, bool exclude=false)
Definition: op.h:2064
MakeLossNormalization
Definition: op.h:6719
Symbol log10(const std::string &symbol_name, Symbol data)
Definition: op.h:1893
Symbol dot(const std::string &symbol_name, Symbol lhs, Symbol rhs, bool transpose_a=false, bool transpose_b=false)
Definition: op.h:1203
Symbol log2(const std::string &symbol_name, Symbol data)
Definition: op.h:1914