32 #ifndef MXNET_LIB_API_H_ 33 #define MXNET_LIB_API_H_ 40 #include <unordered_set> 41 #include <unordered_map> 51 #include <cuda_runtime.h> 52 #include <curand_kernel.h> 56 #define MX_LIBRARY_VERSION 10 63 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) 64 #define PRIVATE_SYMBOL 66 #define PRIVATE_SYMBOL __attribute__ ((visibility ("hidden"))) 72 #ifndef DLPACK_VERSION 74 #define DLPACK_EXTERN_C extern "C" 76 #define DLPACK_EXTERN_C 80 #define DLPACK_VERSION 020 85 #define DLPACK_DLL __declspec(dllexport) 87 #define DLPACK_DLL __declspec(dllimport) 210 uint64_t byte_offset;
227 std::stringstream&
add(
const char* file,
int line);
233 const std::string*
get(
int idx);
241 std::vector<std::stringstream*> messages;
245 #define MX_ERROR_MSG mxnet::ext::MXerrorMsgs::get()->add(__FILE__, __LINE__) 280 explicit MXContext(std::string dev_type_,
int dev_id_);
281 explicit MXContext(
const char* dev_type_,
int dev_id_);
311 int64_t* indptr =
nullptr;
314 void set(
void *data_ptr,
const int64_t* dims,
int ndims,
void *idx,
315 int64_t num_idx,
void *idx_ptr =
nullptr, int64_t num_idx_ptr = 0);
328 void setTensor(
void *dptr,
MXDType type,
const int64_t* dims,
int ndims,
335 template<
typename data_type>
337 return reinterpret_cast<data_type*
>(data_ptr);
341 int64_t
size()
const;
344 bool isSame(
const MXTensor &oth)
const;
371 typedef void* (*xpu_malloc_t)(
void*, int);
375 typedef void (*
nd_malloc_t)(
const void* _ndarray_alloc,
const int64_t* shapes,
int num_shapes,
376 const char* dev_str,
int dev_id,
int dtype,
const char* name,
377 int isArg,
void** data);
379 #if defined(__NVCC__) 390 #define MX_NUM_CPU_RANDOM_STATES 1024 391 #define MX_NUM_GPU_RANDOM_STATES 32768 396 PassResource(std::unordered_map<std::string, MXTensor>* new_args,
397 std::unordered_map<std::string, MXTensor>* new_aux,
401 MXTensor* alloc_arg(
const std::string& name,
const std::vector<int64_t>& shapes,
405 MXTensor* alloc_aux(
const std::string& name,
const std::vector<int64_t>& shapes,
409 std::unordered_map<std::string, MXTensor>* new_args_;
410 std::unordered_map<std::string, MXTensor>* new_aux_;
412 const void* nd_alloc_;
421 xpu_malloc_t gpu_malloc_fp,
void* gpu_alloc_fp,
void* stream,
423 void* rng_cpu_states,
void* rng_gpu_states);
426 void* alloc_cpu(
int size)
const;
429 void* alloc_gpu(
int size)
const;
433 return static_cast<mx_stream_t
>(cuda_stream);
437 void alloc_sparse(
MXSparse* sparse,
int index,
int indices_len,
int indptr_len = 0)
const;
441 mx_cpu_rand_t* get_cpu_rand_states()
const;
447 return static_cast<mx_gpu_rand_t*
>(rand_gpu_states);
454 void *cpu_alloc, *gpu_alloc;
462 void *rand_cpu_states, *rand_gpu_states;
466 #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json" 468 #define MX_STR_DTYPE "__ext_dtype__" 470 #define MX_STR_SHAPE "__ext_shape__" 472 #define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__" 481 std::string
getShapeAt(
const std::string& shape,
unsigned index);
490 std::string
getDtypeAt(
const std::string& dtype,
unsigned index);
504 explicit JsonVal(std::string s);
509 bool operator<(
const JsonVal &o)
const;
512 std::string dump()
const;
515 static JsonVal parse(
const std::string& json);
518 static JsonVal parse_string(
const std::string& json,
unsigned int* idx);
521 static JsonVal parse_num(
const std::string& json,
unsigned int* idx);
524 static JsonVal parse_list(
const std::string& json,
unsigned int* idx);
527 static JsonVal parse_map(
const std::string& json,
unsigned int* idx);
530 static JsonVal parse(
const std::string& json,
unsigned int *idx);
533 std::string toString()
const;
539 std::map<JsonVal, JsonVal>
map;
563 void alloc_arg(
const std::vector<int64_t>& shapes,
567 void alloc_aux(
const std::vector<int64_t>& shapes,
576 std::unordered_map<std::string, std::string>
attrs;
591 static Graph* fromString(
const std::string& json);
600 std::string toString()
const;
603 void _dfs_util(
Node* n, std::unordered_set<Node*>* to_visit,
604 std::function<
void(
Node*)> handler)
const;
607 void DFS(std::function<
void(
Node*)> handler)
const;
610 std::vector<Node*> topological_sort()
const;
613 void print(
int indent = 0)
const;
616 Node* addNode(
const std::string& name,
const std::string& op);
619 Node* getNode(
size_t idx);
622 const Node* getNode(
size_t idx)
const;
625 const JsonVal& getAttr(
const std::string& key)
const;
634 void _setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
635 std::unordered_map<std::string, mxnet::ext::MXTensor>* aux);
639 std::map<std::string, JsonVal>
attrs;
642 std::vector<Node*> nodes;
654 virtual bool Select(
int nodeID) = 0;
660 virtual bool SelectInput(
int nodeID,
int input_nodeID) = 0;
666 virtual bool SelectOutput(
int nodeID,
int output_nodeID) = 0;
672 virtual void Filter(
const std::vector<int>& candidates,
673 std::vector<int>* keep) {
674 keep->insert(keep->end(), candidates.begin(), candidates.end());
690 std::vector<MXTensor>* outputs,
693 std::vector<MXTensor>* outputs,
695 MX_ERROR_MSG <<
"Error! Operator does not support backward" << std::endl;
711 std::string>& attributes,
712 std::vector<MXTensor>* inputs,
713 std::vector<MXTensor>* outputs,
716 std::string>& attributes,
717 int* num_inputs,
int* num_outputs);
719 std::string>& attributes,
720 std::vector<int>* in_types,
721 std::vector<int>* out_types);
723 std::string>& attributes,
724 std::vector<int>* in_storage_types,
725 std::vector<int>* out_storage_types);
727 std::string>& attributes,
728 std::vector<std::vector<unsigned int> >* in_shapes,
729 std::vector<std::vector<unsigned int> >* out_shapes);
731 std::string>& attributes,
732 std::vector<int>* input_indices);
734 std::string>& attributes,
736 const std::vector<std::vector<unsigned int> >& in_shapes,
737 const std::vector<int> in_types,
745 explicit CustomOp(
const char* op_name);
784 void raiseDuplicateContextError();
787 std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
788 std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
793 const std::unordered_map<std::string, std::string>& options);
813 const std::unordered_map<std::string,
814 std::string>& options);
817 const std::unordered_map<std::string,
818 std::string>& options);
821 const std::unordered_map<std::string,
822 std::string>& options,
823 std::unordered_map<std::string,
824 std::string>* attrs);
836 const char* sg_name);
880 T&
add(
const char* name) {
881 T *entry =
new T(name);
882 entries.push_back(entry);
886 return entries.size();
889 return *(entries.at(idx));
898 std::vector<T*> entries;
906 #define MX_STR_CONCAT_(__a, __b) __a ## __b 907 #define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b) 910 #define MX_STRINGIFY(x) #x 911 #define MX_TOSTRING(x) MX_STRINGIFY(x) 914 #define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _ 915 #define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name) 917 #define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _ 918 #define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name) 920 #define MX_REGISTER_PASS_NAME_(Name) MXNet ## _CustomPass ## _ 921 #define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name) 924 #define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \ 925 mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->add(MX_TOSTRING(Name)) 927 #define REGISTER_PARTITIONER(Name) \ 928 MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \ 929 mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->add(MX_TOSTRING(Name)) 931 #define REGISTER_PASS(Name) \ 932 MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \ 933 mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->add(MX_TOSTRING(Name)) 942 #define MXLIB_OPREGSIZE_STR "_opRegSize" 945 #define MXLIB_OPREGGET_STR "_opRegGet" 946 typedef int (*
opRegGet_t)(
int idx,
const char** name,
int *isSGop,
948 int* forward_count,
const char*** backward_ctx,
955 #define MXLIB_OPCALLFREE_STR "_opCallFree" 958 #define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs" 960 const char*
const* vals,
int num,
961 int* num_in,
int* num_out);
963 #define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape" 965 const char*
const* vals,
int num,
966 unsigned int** inshapes,
int* indims,
int num_in,
967 unsigned int*** mod_inshapes,
int** mod_indims,
968 unsigned int*** outshapes,
int** outdims,
int num_out);
970 #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType" 972 const char*
const* vals,
int num,
973 int* intypes,
int num_in,
int* outtypes,
int num_out);
975 #define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType" 977 const char*
const* vals,
int num,
978 int* intypes,
int num_in,
int* outtypes,
int num_out);
980 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute" 982 const char*
const* vals,
int num,
983 const int64_t** inshapes,
int* indims,
984 void** indata,
int* intypes,
985 size_t* inIDs,
const char** indev_type,
986 int* indev_id,
int num_in,
987 const int64_t** outshapes,
int* outdims,
988 void** outdata,
int* outtypes,
989 size_t* outIDs,
const char** outdev_type,
990 int* outdev_id,
int num_out,
992 xpu_malloc_t gpu_malloc,
void* gpu_alloc,
void* cuda_stream,
994 int* instypes,
int* outstypes,
995 void** in_indices,
void** out_indices,
996 void** in_indptr,
void** out_indptr,
997 int64_t* in_indices_shapes, int64_t* out_indices_shapes,
998 int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
999 void* rng_cpu_states,
void* rng_gpu_states);
1001 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs" 1003 const char*
const* vals,
int num,
1004 int** mutate_indices,
int* indices_size);
1006 #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState" 1008 const char*
const* vals,
int num,
const char* dev_type,
1009 int dev_id,
unsigned int** inshapes,
int* indims,
1010 int num_in,
const int* intypes,
void** state_op);
1012 #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" 1014 const int64_t** inshapes,
int* indims,
1015 void** indata,
int* intypes,
1016 size_t* inIDs,
const char** indev_type,
1017 int* indev_id,
int num_in,
1018 const int64_t** outshapes,
int* outdims,
1019 void** outdata,
int* outtypes,
1020 size_t* outIDs,
const char** outdev_type,
1021 int* outdev_id,
int num_out,
1023 xpu_malloc_t gpu_malloc,
void* gpu_alloc,
void* stream,
1025 int* instypes,
int* outstypes,
1026 void** in_indices,
void** out_indices,
1027 void** in_indptr,
void** out_indptr,
1028 int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1029 int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1030 void* rng_cpu_states,
void* rng_gpu_states);
1032 #define MXLIB_PARTREGSIZE_STR "_partRegSize" 1035 #define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount" 1038 #define MXLIB_PARTREGGET_STR "_partRegGet" 1039 typedef void (*
partRegGet_t)(
int part_idx,
int stg_idx,
const char** strategy,
1043 #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps" 1045 int num_ids,
int *ids,
const char*
const* opt_keys,
1046 const char*
const* opt_vals,
int num_opts);
1048 #define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector" 1050 void** selector,
const char*
const* opt_keys,
1051 const char*
const* opt_vals,
int num_opts);
1053 #define MXLIB_PARTCALLSELECT_STR "_partCallSelect" 1056 #define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput" 1060 #define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput" 1064 #define MXLIB_PARTCALLFILTER_STR "_partCallFilter" 1066 int** keep,
int* num_keep);
1068 #define MXLIB_PARTCALLRESET_STR "_partCallReset" 1071 #define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph" 1073 int subgraph_id,
int *accept,
const char*
const* opt_keys,
1074 const char*
const* opt_vals,
int num_opts,
1075 char*** attr_keys,
char*** attr_vals,
int *num_attrs,
1076 const char*
const* arg_names,
int num_args,
1077 void*
const* arg_data,
const int64_t*
const* arg_shapes,
1078 const int* arg_dims,
const int* arg_types,
1079 const size_t* arg_IDs,
const char*
const* arg_dev_type,
1080 const int* arg_dev_id,
1081 const char*
const* aux_names,
int num_aux,
1082 void*
const* aux_data,
const int64_t*
const* aux_shapes,
1083 const int* aux_dims,
const int* aux_types,
1084 const size_t* aux_IDs,
const char*
const* aux_dev_type,
1085 const int* aux_dev_id);
1087 #define MXLIB_PASSREGSIZE_STR "_passRegSize" 1090 #define MXLIB_PASSREGGET_STR "_passRegGet" 1093 #define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass" 1095 char** out_graph,
const char*
const* opt_keys,
1096 const char*
const* opt_vals,
int num_opts,
1097 const char* pass_name,
const char*
const* arg_names,
1098 int num_args,
void*
const* arg_data,
1099 const int64_t*
const* arg_shapes,
const int* arg_dims,
1100 const int* arg_types,
const size_t* arg_IDs,
1101 const char*
const* arg_dev_type,
const int* arg_dev_id,
1102 const char*
const* aux_names,
int num_aux,
1103 void*
const* aux_data,
const int64_t*
const* aux_shapes,
1104 const int* aux_dims,
const int* aux_types,
1105 const size_t* aux_IDs,
const char*
const* aux_dev_type,
1107 const void* nd_alloc);
1109 #define MXLIB_INITIALIZE_STR "initialize" 1112 #define MXLIB_OPVERSION_STR "_opVersion" 1115 #define MXLIB_MSGSIZE_STR "_msgSize" 1118 #define MXLIB_MSGGET_STR "_msgGet" 1121 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) 1122 #define MX_INT_RET __declspec(dllexport) int __cdecl 1123 #define MX_VOID_RET __declspec(dllexport) void __cdecl 1125 #define MX_INT_RET int 1126 #define MX_VOID_RET void 1142 int* forward_count,
const char*** backward_ctx,
1154 const char*
const* vals,
int num,
1155 int* num_in,
int* num_out);
1159 const char*
const* vals,
int num,
1160 unsigned int** inshapes,
int* indims,
int num_in,
1161 unsigned int*** mod_inshapes,
int** mod_indims,
1162 unsigned int*** outshapes,
int** outdims,
int num_out);
1166 const char*
const* vals,
int num,
1167 int* intypes,
int num_in,
int* outtypes,
int num_out);
1171 const char*
const* vals,
int num,
1172 int* instypes,
int num_in,
int* outstypes,
int num_out);
1176 const char*
const* vals,
1177 int num,
const int64_t** inshapes,
int* indims,
void** indata,
1178 int* intypes,
size_t* inIDs,
const char** indev_type,
int* indev_id,
1179 int num_in,
const int64_t** outshapes,
int* outdims,
void** outdata,
1180 int* outtypes,
size_t* outIDs,
const char** outdev_type,
1186 int* instypes,
int* outstypes,
void** in_indices,
void** out_indices,
1187 void** in_indptr,
void** out_indptr,
1188 int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1189 int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1190 void* rng_cpu_states,
void* rng_gpu_states);
1194 const char*
const* vals,
int num,
1195 int** mutate_indices,
int* indices_size);
1199 const char*
const* vals,
int num,
const char* dev_type,
1200 int dev_id,
unsigned int** inshapes,
int* indims,
1201 int num_in,
const int* intypes,
void** state_op);
1205 int* indims,
void** indata,
int* intypes,
size_t* inIDs,
1206 const char** indev_type,
int* indev_id,
int num_in,
1207 const int64_t** outshapes,
int* outdims,
void** outdata,
1208 int* outtypes,
size_t* outIDs,
const char** outdev_type,
1209 int* outdev_id,
int num_out,
1214 void* sparse_alloc,
int* instypes,
int* outstypes,
1215 void** in_indices,
void** out_indices,
void** in_indptr,
1216 void** out_indptr, int64_t* in_indices_shapes,
1217 int64_t* out_indices_shapes, int64_t* in_indptr_shapes,
1218 int64_t* out_indptr_shapes,
1219 void* rng_cpu_states,
void* rng_gpu_states);
1236 int num_ids,
int *ids,
const char*
const* opt_keys,
1237 const char*
const* opt_vals,
int num_opts);
1241 void** selector,
const char*
const* opt_keys,
1242 const char*
const* opt_vals,
int num_opts);
1249 int input_nodeID,
int* selected);
1253 int output_nodeID,
int* selected);
1257 int** keep,
int* num_keep);
1264 int subgraph_id,
int *accept,
const char*
const* opt_keys,
1265 const char*
const* opt_vals,
int num_opts,
1266 char*** attr_keys,
char*** attr_vals,
int *num_attrs,
1267 const char*
const* arg_names,
int num_args,
1268 void*
const* arg_data,
const int64_t*
const* arg_shapes,
1269 const int* arg_dims,
const int* arg_types,
1270 const size_t* arg_IDs,
const char*
const* arg_dev_type,
1271 const int* arg_dev_id,
1272 const char*
const* aux_names,
int num_aux,
1273 void*
const* aux_data,
const int64_t*
const* aux_shapes,
1274 const int* aux_dims,
const int* aux_types,
1275 const size_t* aux_IDs,
const char*
const* aux_dev_type,
1276 const int* aux_dev_id);
1283 const char** pass_name);
1287 char** out_graph,
const char*
const* opt_keys,
1288 const char*
const* opt_vals,
int num_opts,
1289 const char* pass_name,
const char*
const* arg_names,
int num_args,
1290 void*
const* arg_data,
const int64_t*
const* arg_shapes,
1291 const int* arg_dims,
const int* arg_types,
1292 const size_t* arg_IDs,
const char*
const* arg_dev_type,
1293 const int* arg_dev_id,
const char*
const* aux_names,
int num_aux,
1294 void*
const* aux_data,
const int64_t*
const* aux_shapes,
1295 const int* aux_dims,
const int* aux_types,
1296 const size_t* aux_IDs,
const char*
const* aux_dev_type,
1298 const void* nd_alloc);
1307 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) 1320 #endif // MXNET_LIB_API_H_ An abstract class for subgraph property.
Definition: lib_api.h:829
int(* opVersion_t)()
Definition: lib_api.h:1113
int(* partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1044
const char * name
partitioner name
Definition: lib_api.h:851
MXReturnValue(* inferShape_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< std::vector< unsigned int > > *in_shapes, std::vector< std::vector< unsigned int > > *out_shapes)
Definition: lib_api.h:726
DLDeviceType
The device type in DLContext.
Definition: lib_api.h:102
std::string getShapeAt(const std::string &shape, unsigned index)
int(* opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
Definition: lib_api.h:959
MX_VOID_RET _opRegGet(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
returns operator registration at specified index
virtual MXReturnValue Backward(std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &op_res)
Definition: lib_api.h:692
int(* msgGet_t)(int idx, const char **msg)
Definition: lib_api.h:1119
const char * name
operator name
Definition: lib_api.h:768
Definition: lib_api.h:258
MXTensor * tensor
Definition: lib_api.h:572
virtual void Filter(const std::vector< int > &candidates, std::vector< int > *keep)
Definition: lib_api.h:672
provide resource APIs memory allocation mechanism to Forward/Backward functions
Definition: lib_api.h:418
Definition: lib_api.h:256
void(* partCallReset_t)(void *sel_inst)
Definition: lib_api.h:1069
std::vector< const char * > strategies
strategy names
Definition: lib_api.h:856
OpenCL devices.
Definition: lib_api.h:113
std::vector< NodeEntry > outputs
Definition: lib_api.h:574
DLDataTypeCode
The type code options DLDataType.
Definition: lib_api.h:143
MX_VOID_RET _partCallSelectOutput(void *sel_inst, int nodeID, int output_nodeID, int *selected)
returns status of calling select output function from library
StatefulOp wrapper class to pass to backend OpState.
Definition: lib_api.h:701
MX_INT_RET _opRegSize()
returns number of ops registered in this library
T & add(const char *name)
add a new entry
Definition: lib_api.h:880
Metal for Apple GPU.
Definition: lib_api.h:117
Definition: lib_api.h:297
int(* passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, nd_malloc_t nd_malloc, const void *nd_alloc)
Definition: lib_api.h:1094
namespace of mxnet
Definition: api_registry.h:33
Definition: lib_api.h:146
void(* partCallSelectInput_t)(void *sel_inst, int nodeID, int input_nodeID, int *selected)
Definition: lib_api.h:1057
#define MX_VOID_RET
Definition: lib_api.h:1126
ROCm GPUs for AMD GPUs.
Definition: lib_api.h:121
A Device context for Tensor and operator.
Definition: dlpack.h:69
std::vector< Graph * > subgraphs
Definition: lib_api.h:575
CUDA GPU device.
Definition: lib_api.h:106
Node * node
Definition: lib_api.h:550
void(* partRegGet_t)(int part_idx, int stg_idx, const char **strategy, supportedOps_t *supportedOps, createSelector_t *createSelector, reviewSubgraph_t *reviewSubgraph, const char **op_name)
Definition: lib_api.h:1039
MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
returns status of calling mutateInputs function for operator from library
int64_t * indices
Definition: lib_api.h:306
int64_t data_len
Definition: lib_api.h:301
MXStorageType
Definition: lib_api.h:264
Reserved extension device type, used for quickly test extension device The semantics can differ depen...
Definition: lib_api.h:127
Pinned CUDA GPU device by cudaMallocHost.
Definition: lib_api.h:111
inferSType_t infer_storage_type
Definition: lib_api.h:773
MXContext ctx
Definition: lib_api.h:360
#define PRIVATE_SYMBOL
For loading multiple custom op libraries in Linux, exporting same symbol multiple times may lead to u...
Definition: lib_api.h:66
MX_INT_RET _partRegSize()
returns number of partitioners registered in this library
Definition: lib_api.h:496
MXReturnValue(* graphPass_t)(mxnet::ext::Graph *graph, const std::unordered_map< std::string, std::string > &options)
Custom Pass Create function template.
Definition: lib_api.h:792
MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling create selector function from library
MX_VOID_RET _partCallSelect(void *sel_inst, int nodeID, int *selected)
returns status of calling select function from library
Definition: lib_api.h:255
int num
Definition: lib_api.h:536
Tensor data structure used by custom operator.
Definition: lib_api.h:321
MX_VOID_RET _partCallFilter(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
returns status of calling filter function from library
std::unordered_map< std::string, std::string > attrs
Definition: lib_api.h:576
void *(* xpu_malloc_t)(void *, int)
resource malloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:371
int(* opCallInferShape_t)(inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
Definition: lib_api.h:964
MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *instypes, int num_in, int *outstypes, int num_out)
returns status of calling inferSType function for operator from library
MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling supported ops function from library
MX_VOID_RET _partCallSelectInput(void *sel_inst, int nodeID, int input_nodeID, int *selected)
returns status of calling select input function from library
MXReturnValue(* createOpState_t)(const std::unordered_map< std::string, std::string > &attributes, const MXContext &ctx, const std::vector< std::vector< unsigned int > > &in_shapes, const std::vector< int > in_types, CustomStatefulOp **)
Definition: lib_api.h:733
MXReturnValue
Definition: lib_api.h:291
void(* partCallSelectOutput_t)(void *sel_inst, int nodeID, int output_nodeID, int *selected)
Definition: lib_api.h:1061
std::string name
Definition: lib_api.h:571
MX_INT_RET _opCallFStatefulCompute(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Stateful Forward/Backward for operator from library
CustomStatefulOp * get_instance()
Definition: lib_api.h:704
std::map< std::string, createSelector_t > selector_map
Definition: lib_api.h:853
std::vector< NodeEntry > inputs
Definition: lib_api.h:573
Definition: lib_api.h:254
CustomStatefulOpWrapper(CustomStatefulOp *inst)
Definition: lib_api.h:703
graphPass_t pass
pass function
Definition: lib_api.h:808
Definition: lib_api.h:257
std::string str
Definition: lib_api.h:537
MXReturnValue(* createSelector_t)(const mxnet::ext::Graph *graph, CustomOpSelector **sel_inst, const std::unordered_map< std::string, std::string > &options)
Definition: lib_api.h:815
MXDType dtype
Definition: lib_api.h:354
int(* opCallMutateInputs_t)(mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
Definition: lib_api.h:1002
Definition: lib_api.h:144
inferType_t infer_type
Definition: lib_api.h:772
std::map< std::string, JsonVal > attrs
Definition: lib_api.h:639
void(* partCallSelect_t)(void *sel_inst, int nodeID, int *selected)
Definition: lib_api.h:1054
int entry
Definition: lib_api.h:551
Definition: lib_api.h:251
Definition: lib_api.h:221
Class to hold custom operator registration.
Definition: lib_api.h:743
MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
returns status of calling inferType function for operator from library
Definition: lib_api.h:583
int(* opRegGet_t)(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
Definition: lib_api.h:946
Vulkan buffer for next generation graphics.
Definition: lib_api.h:115
MX_INT_RET _opVersion()
returns MXNet library version
inferShape_t infer_shape
Definition: lib_api.h:774
#define MX_ERROR_MSG
Definition: lib_api.h:245
int64_t indices_len
Definition: lib_api.h:307
MXReturnValue(* inferSType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_storage_types, std::vector< int > *out_storage_types)
Definition: lib_api.h:722
std::vector< const char * > op_names
subgraph operator name
Definition: lib_api.h:858
MXReturnValue(* reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, bool *accept, const std::unordered_map< std::string, std::string > &options, std::unordered_map< std::string, std::string > *attrs)
Definition: lib_api.h:819
DLDeviceType
The device type in DLContext.
Definition: dlpack.h:38
MX_INT_RET _partRegGetCount(int idx, const char **name)
void * mx_gpu_rand_t
Definition: lib_api.h:384
Verilog simulator buffer.
Definition: lib_api.h:119
mx_stream_t get_cuda_stream() const
return the cuda stream object with correct type
Definition: lib_api.h:432
std::vector< createOpState_t > create_op_fp
Definition: lib_api.h:781
MXReturnValue(* fcomp_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &res)
Custom Operator function templates.
Definition: lib_api.h:710
mxnet::ext::MXReturnValue initialize(int version)
Checks if the MXNet version is supported by the library. If supported, initializes the library...
JsonType
Json utility to parse serialized subgraph symbol.
Definition: lib_api.h:496
MX_VOID_RET _opCallFree(void *ptr)
calls free from the external library for library allocated arrays
MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
returns status of calling review subgraph function from library
int(* opCallInferSType_t)(inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:976
MX_INT_RET _passRegSize()
returns number of graph passes registered in this library
Definition: lib_api.h:145
int(* passRegSize_t)(void)
Definition: lib_api.h:1088
definition of JSON objects
Definition: lib_api.h:499
Definition: lib_api.h:549
MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, const void *nd_alloc)
returns status of calling graph pass function from library
int(* partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
Definition: lib_api.h:1072
std::string getDtypeAt(const std::string &dtype, unsigned index)
bool isSGop
Definition: lib_api.h:776
MXReturnValue(* inferType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_types, std::vector< int > *out_types)
Definition: lib_api.h:718
int(* opCallFree_t)(void *ptr)
Definition: lib_api.h:956
int(* opCallFStatefulComp_t)(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:1013
size_t verID
Definition: lib_api.h:357
std::vector< JsonVal > list
Definition: lib_api.h:538
MXReturnValue(* supportedOps_t)(const mxnet::ext::Graph *graph, std::vector< int > *ids, const std::unordered_map< std::string, std::string > &options)
Custom Subgraph Create function template.
Definition: lib_api.h:812
#define MX_INT_RET
Definition: lib_api.h:1125
const char * name
pass name
Definition: lib_api.h:806
parseAttrs_t parse_attrs
operator functions
Definition: lib_api.h:771
Definition: lib_api.h:394
data_type * data()
helper function to cast data pointer
Definition: lib_api.h:336
MXReturnValue(* parseAttrs_t)(const std::unordered_map< std::string, std::string > &attributes, int *num_inputs, int *num_outputs)
Definition: lib_api.h:715
mutateInputs_t mutate_inputs
Definition: lib_api.h:775
Context info passed from MXNet OpContext. dev_type is the string representation of the supported context (currently only "cpu" and "gpu"); dev_id is the index of the device where the tensor is located.
Definition: lib_api.h:278
std::mt19937 mx_cpu_rand_t
Definition: lib_api.h:386
MX_VOID_RET _msgGet(int idx, const char **msg)
returns error message at specified index
MX_VOID_RET _partCallReset(void *sel_inst)
returns status of calling reset selector function from library
Definition: lib_api.h:292
void * mx_stream_t
GPU stream pointer, is void* when not compiled with CUDA.
Definition: lib_api.h:383
An abstract class for library authors creating stateful op custom library should override Forward and...
Definition: lib_api.h:687
int(* partCallCreateSelector_t)(createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1049
std::map< std::string, supportedOps_t > supported_map
Definition: lib_api.h:852
int size()
Definition: lib_api.h:885
Definition: lib_api.h:293
std::vector< const char * > forward_ctx_cstr
vector repr of ctx map to be easily loaded from c_api
Definition: lib_api.h:779
CPU device.
Definition: lib_api.h:104
int(* partRegSize_t)(void)
Definition: lib_api.h:1033
Definition: lib_api.h:555
void(* passRegGet_t)(int pass_idx, graphPass_t *graphPass, const char **pass_name)
Definition: lib_api.h:1091
Definition: lib_api.h:496
void(* sparse_malloc_t)(void *, int, int, int, void **, int64_t **, int64_t **)
sparse alloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:373
MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t *graphPass, const char **pass_name)
returns pass registration at specified index
void(* partCallFilter_t)(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
Definition: lib_api.h:1065
std::vector< NodeEntry > outputs
Definition: lib_api.h:638
int64_t indptr_len
Definition: lib_api.h:312
int dev_id
Definition: lib_api.h:288
std::vector< Node * > inputs
Definition: lib_api.h:637
DLTensor dltensor
Definition: lib_api.h:364
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
returns status of calling createStatefulOp function for operator from library
int(* opCallInferType_t)(inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:971
std::string dev_type
Definition: lib_api.h:287
int(* opCallCreateOpState_t)(createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
Definition: lib_api.h:1007
MXDType
Tensor data type, consistent with mshadow data type.
Definition: lib_api.h:250
int(* partRegGetCount_t)(int idx, const char **name)
Definition: lib_api.h:1036
std::vector< fcomp_t > forward_fp
Definition: lib_api.h:780
Definition: lib_api.h:496
MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
returns status of calling inferShape function for operator from library
int(* initialize_t)(int version)
Definition: lib_api.h:1110
std::map< JsonVal, JsonVal > map
Definition: lib_api.h:539
Definition: lib_api.h:270
int(* opRegSize_t)(void)
Definition: lib_api.h:943
An abstract class for graph passes.
Definition: lib_api.h:798
std::vector< int64_t > shape
Definition: lib_api.h:351
Registry class to register things (ops, properties). Singleton class.
Definition: lib_api.h:866
JsonType type
Definition: lib_api.h:535
Definition: lib_api.h:268
The data type the tensor can hold.
Definition: dlpack.h:94
void(* nd_malloc_t)(const void *_ndarray_alloc, const int64_t *shapes, int num_shapes, const char *dev_str, int dev_id, int dtype, const char *name, int isArg, void **data)
resource malloc function to allocate ndarrays for graph passes
Definition: lib_api.h:375
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Forward/Backward function for operator from library
Definition: lib_api.h:649
std::stringstream & add(const char *file, int line)
virtual void Reset()
Definition: lib_api.h:679
Definition: lib_api.h:266
std::string op
Definition: lib_api.h:570
Definition: lib_api.h:252
Definition: lib_api.h:253
Definition: lib_api.h:496
Definition: lib_api.h:496
MXReturnValue(* mutateInputs_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *input_indices)
Definition: lib_api.h:730
int(* msgSize_t)(void)
Definition: lib_api.h:1116
mx_gpu_rand_t * get_gpu_rand_states() const
get pointer to initialized and seeded random number states located on GPU
Definition: lib_api.h:446
std::map< std::string, reviewSubgraph_t > review_map
Definition: lib_api.h:854
MXStorageType stype
Definition: lib_api.h:367
int(* opCallFComp_t)(fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:981
MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char **strategy, mxnet::ext::supportedOps_t *supportedOps, mxnet::ext::createSelector_t *createSelector, mxnet::ext::reviewSubgraph_t *reviewSubgraph, const char **op_name)
returns partitioner registration at specified index
MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
returns status of calling parse attributes function for operator from library
void * data_ptr
Definition: lib_api.h:348