31 #ifndef MXNET_LIB_API_H_ 32 #define MXNET_LIB_API_H_ 39 #include <unordered_set> 40 #include <unordered_map> 50 #include <cuda_runtime.h> 51 #include <curand_kernel.h> 55 #define MX_LIBRARY_VERSION 11 62 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) 63 #define PRIVATE_SYMBOL 65 #define PRIVATE_SYMBOL __attribute__ ((visibility ("hidden"))) 71 #ifndef DLPACK_VERSION 73 #define DLPACK_EXTERN_C extern "C" 75 #define DLPACK_EXTERN_C 79 #define DLPACK_VERSION 020 84 #define DLPACK_DLL __declspec(dllexport) 86 #define DLPACK_DLL __declspec(dllimport) 209 uint64_t byte_offset;
226 std::stringstream&
add(
const char* file,
int line);
232 const std::string*
get(
int idx);
240 std::vector<std::stringstream*> messages;
244 #define MX_ERROR_MSG mxnet::ext::MXerrorMsgs::get()->add(__FILE__, __LINE__) 279 explicit MXContext(std::string dev_type_,
int dev_id_);
280 explicit MXContext(
const char* dev_type_,
int dev_id_);
310 int64_t* indptr =
nullptr;
313 void set(
void *data_ptr,
const int64_t* dims,
int ndims,
void *idx,
314 int64_t num_idx,
void *idx_ptr =
nullptr, int64_t num_idx_ptr = 0);
327 void setTensor(
void *dptr,
MXDType type,
const int64_t* dims,
int ndims,
334 template<
typename data_type>
336 return reinterpret_cast<data_type*
>(data_ptr);
340 int64_t
size()
const;
343 bool isSame(
const MXTensor &oth)
const;
370 typedef void* (*xpu_malloc_t)(
void*, int);
374 typedef void (*
nd_malloc_t)(
const void* _ndarray_alloc,
const int64_t* shapes,
int num_shapes,
375 const char* dev_str,
int dev_id,
int dtype,
const char* name,
376 int isArg,
void** data);
378 #if defined(__NVCC__) 389 #define MX_NUM_CPU_RANDOM_STATES 1024 390 #define MX_NUM_GPU_RANDOM_STATES 32768 395 PassResource(std::unordered_map<std::string, MXTensor>* new_args,
396 std::unordered_map<std::string, MXTensor>* new_aux,
400 MXTensor* alloc_arg(
const std::string& name,
const std::vector<int64_t>& shapes,
404 MXTensor* alloc_aux(
const std::string& name,
const std::vector<int64_t>& shapes,
408 std::unordered_map<std::string, MXTensor>* new_args_;
409 std::unordered_map<std::string, MXTensor>* new_aux_;
411 const void* nd_alloc_;
420 xpu_malloc_t gpu_malloc_fp,
void* gpu_alloc_fp,
void* stream,
422 void* rng_cpu_states,
void* rng_gpu_states);
425 void* alloc_cpu(
int size)
const;
428 void* alloc_gpu(
int size)
const;
432 return static_cast<mx_stream_t
>(cuda_stream);
436 void alloc_sparse(
MXSparse* sparse,
int index,
int indices_len,
int indptr_len = 0)
const;
440 mx_cpu_rand_t* get_cpu_rand_states()
const;
446 return static_cast<mx_gpu_rand_t*
>(rand_gpu_states);
453 void *cpu_alloc, *gpu_alloc;
461 void *rand_cpu_states, *rand_gpu_states;
465 #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json" 467 #define MX_STR_DTYPE "__ext_dtype__" 469 #define MX_STR_SHAPE "__ext_shape__" 471 #define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__" 480 std::string
getShapeAt(
const std::string& shape,
unsigned index);
489 std::string
getDtypeAt(
const std::string& dtype,
unsigned index);
503 explicit JsonVal(std::string s);
508 bool operator<(
const JsonVal &o)
const;
511 std::string dump()
const;
514 static JsonVal parse(
const std::string& json);
517 static JsonVal parse_string(
const std::string& json,
unsigned int* idx);
520 static JsonVal parse_num(
const std::string& json,
unsigned int* idx);
523 static JsonVal parse_list(
const std::string& json,
unsigned int* idx);
526 static JsonVal parse_map(
const std::string& json,
unsigned int* idx);
529 static JsonVal parse(
const std::string& json,
unsigned int *idx);
532 std::string toString()
const;
538 std::map<JsonVal, JsonVal>
map;
562 void alloc_arg(
const std::vector<int64_t>& shapes,
566 void alloc_aux(
const std::vector<int64_t>& shapes,
575 std::unordered_map<std::string, std::string>
attrs;
590 static Graph* fromString(
const std::string& json);
599 std::string toString()
const;
602 void _dfs_util(
Node* n, std::unordered_set<Node*>* to_visit,
603 std::function<
void(
Node*)> handler)
const;
606 void DFS(std::function<
void(
Node*)> handler)
const;
609 std::vector<Node*> topological_sort()
const;
612 void print(
int indent = 0)
const;
615 Node* addNode(
const std::string& name,
const std::string& op);
618 Node* getNode(
size_t idx);
621 const Node* getNode(
size_t idx)
const;
624 const JsonVal& getAttr(
const std::string& key)
const;
633 void _setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
634 std::unordered_map<std::string, mxnet::ext::MXTensor>* aux);
638 std::map<std::string, JsonVal>
attrs;
641 std::vector<Node*> nodes;
653 virtual bool Select(
int nodeID) = 0;
659 virtual bool SelectInput(
int nodeID,
int input_nodeID) = 0;
665 virtual bool SelectOutput(
int nodeID,
int output_nodeID) = 0;
671 virtual void Filter(
const std::vector<int>& candidates,
672 std::vector<int>* keep) {
673 keep->insert(keep->end(), candidates.begin(), candidates.end());
691 template<
class A,
typename ...Ts>
701 std::vector<MXTensor>* outputs,
704 std::vector<MXTensor>* outputs,
706 MX_ERROR_MSG <<
"Error! Operator does not support backward" << std::endl;
718 std::string>& attributes,
719 std::vector<MXTensor>* inputs,
720 std::vector<MXTensor>* outputs,
723 std::string>& attributes,
724 int* num_inputs,
int* num_outputs);
726 std::string>& attributes,
727 std::vector<int>* in_types,
728 std::vector<int>* out_types);
730 std::string>& attributes,
731 std::vector<int>* in_storage_types,
732 std::vector<int>* out_storage_types);
734 std::string>& attributes,
735 std::vector<std::vector<unsigned int> >* in_shapes,
736 std::vector<std::vector<unsigned int> >* out_shapes);
738 std::string>& attributes,
739 std::vector<int>* input_indices);
741 std::string>& attributes,
743 const std::vector<std::vector<unsigned int> >& in_shapes,
744 const std::vector<int> in_types,
752 explicit CustomOp(
const char* op_name);
791 void raiseDuplicateContextError();
794 std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
795 std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
800 const std::unordered_map<std::string, std::string>& options);
820 const std::unordered_map<std::string,
821 std::string>& options);
824 const std::unordered_map<std::string,
825 std::string>& options);
828 const std::unordered_map<std::string,
829 std::string>& options,
830 std::unordered_map<std::string,
831 std::string>* attrs);
843 const char* sg_name);
887 T&
add(
const char* name) {
888 T *entry =
new T(name);
889 entries.push_back(entry);
893 return entries.size();
896 return *(entries.at(idx));
905 std::vector<T*> entries;
913 #define MX_STR_CONCAT_(__a, __b) __a ## __b 914 #define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b) 917 #define MX_STRINGIFY(x) #x 918 #define MX_TOSTRING(x) MX_STRINGIFY(x) 921 #define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _ ## Name 922 #define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name) 924 #define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _ ## Name 925 #define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name) 927 #define MX_REGISTER_PASS_NAME_(Name) MXNet ## _CustomPass ## _ ## Name 928 #define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name) 931 #define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \ 932 mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->add(MX_TOSTRING(Name)) 934 #define REGISTER_PARTITIONER(Name) \ 935 MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \ 936 mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->add(MX_TOSTRING(Name)) 938 #define REGISTER_PASS(Name) \ 939 MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \ 940 mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->add(MX_TOSTRING(Name)) 949 #define MXLIB_OPREGSIZE_STR "_opRegSize" 952 #define MXLIB_OPREGGET_STR "_opRegGet" 953 typedef int (*
opRegGet_t)(
int idx,
const char** name,
int *isSGop,
955 int* forward_count,
const char*** backward_ctx,
962 #define MXLIB_OPCALLFREE_STR "_opCallFree" 965 #define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs" 967 const char*
const* vals,
int num,
968 int* num_in,
int* num_out);
970 #define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape" 972 const char*
const* vals,
int num,
973 unsigned int** inshapes,
int* indims,
int num_in,
974 unsigned int*** mod_inshapes,
int** mod_indims,
975 unsigned int*** outshapes,
int** outdims,
int num_out);
977 #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType" 979 const char*
const* vals,
int num,
980 int* intypes,
int num_in,
int* outtypes,
int num_out);
982 #define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType" 984 const char*
const* vals,
int num,
985 int* intypes,
int num_in,
int* outtypes,
int num_out);
987 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute" 989 const char*
const* vals,
int num,
990 const int64_t** inshapes,
int* indims,
991 void** indata,
int* intypes,
992 size_t* inIDs,
const char** indev_type,
993 int* indev_id,
int num_in,
994 const int64_t** outshapes,
int* outdims,
995 void** outdata,
int* outtypes,
996 size_t* outIDs,
const char** outdev_type,
997 int* outdev_id,
int num_out,
999 xpu_malloc_t gpu_malloc,
void* gpu_alloc,
void* cuda_stream,
1001 int* instypes,
int* outstypes,
1002 void** in_indices,
void** out_indices,
1003 void** in_indptr,
void** out_indptr,
1004 int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1005 int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1006 void* rng_cpu_states,
void* rng_gpu_states);
1008 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs" 1010 const char*
const* vals,
int num,
1011 int** mutate_indices,
int* indices_size);
1013 #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState" 1015 const char*
const* vals,
int num,
const char* dev_type,
1016 int dev_id,
unsigned int** inshapes,
int* indims,
1017 int num_in,
const int* intypes,
void** state_op);
1019 #define MXLIB_OPCALLDESTROYOPSTATE_STR "_opCallDestroyOpState" 1022 #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" 1024 const int64_t** inshapes,
int* indims,
1025 void** indata,
int* intypes,
1026 size_t* inIDs,
const char** indev_type,
1027 int* indev_id,
int num_in,
1028 const int64_t** outshapes,
int* outdims,
1029 void** outdata,
int* outtypes,
1030 size_t* outIDs,
const char** outdev_type,
1031 int* outdev_id,
int num_out,
1033 xpu_malloc_t gpu_malloc,
void* gpu_alloc,
void* stream,
1035 int* instypes,
int* outstypes,
1036 void** in_indices,
void** out_indices,
1037 void** in_indptr,
void** out_indptr,
1038 int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1039 int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1040 void* rng_cpu_states,
void* rng_gpu_states);
1042 #define MXLIB_PARTREGSIZE_STR "_partRegSize" 1045 #define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount" 1048 #define MXLIB_PARTREGGET_STR "_partRegGet" 1049 typedef void (*
partRegGet_t)(
int part_idx,
int stg_idx,
const char** strategy,
1053 #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps" 1055 int num_ids,
int *ids,
const char*
const* opt_keys,
1056 const char*
const* opt_vals,
int num_opts);
1058 #define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector" 1060 void** selector,
const char*
const* opt_keys,
1061 const char*
const* opt_vals,
int num_opts);
1063 #define MXLIB_PARTCALLSELECT_STR "_partCallSelect" 1066 #define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput" 1070 #define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput" 1074 #define MXLIB_PARTCALLFILTER_STR "_partCallFilter" 1076 int** keep,
int* num_keep);
1078 #define MXLIB_PARTCALLRESET_STR "_partCallReset" 1081 #define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph" 1083 int subgraph_id,
int *accept,
const char*
const* opt_keys,
1084 const char*
const* opt_vals,
int num_opts,
1085 char*** attr_keys,
char*** attr_vals,
int *num_attrs,
1086 const char*
const* arg_names,
int num_args,
1087 void*
const* arg_data,
const int64_t*
const* arg_shapes,
1088 const int* arg_dims,
const int* arg_types,
1089 const size_t* arg_IDs,
const char*
const* arg_dev_type,
1090 const int* arg_dev_id,
1091 const char*
const* aux_names,
int num_aux,
1092 void*
const* aux_data,
const int64_t*
const* aux_shapes,
1093 const int* aux_dims,
const int* aux_types,
1094 const size_t* aux_IDs,
const char*
const* aux_dev_type,
1095 const int* aux_dev_id);
1097 #define MXLIB_PASSREGSIZE_STR "_passRegSize" 1100 #define MXLIB_PASSREGGET_STR "_passRegGet" 1103 #define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass" 1105 char** out_graph,
const char*
const* opt_keys,
1106 const char*
const* opt_vals,
int num_opts,
1107 const char* pass_name,
const char*
const* arg_names,
1108 int num_args,
void*
const* arg_data,
1109 const int64_t*
const* arg_shapes,
const int* arg_dims,
1110 const int* arg_types,
const size_t* arg_IDs,
1111 const char*
const* arg_dev_type,
const int* arg_dev_id,
1112 const char*
const* aux_names,
int num_aux,
1113 void*
const* aux_data,
const int64_t*
const* aux_shapes,
1114 const int* aux_dims,
const int* aux_types,
1115 const size_t* aux_IDs,
const char*
const* aux_dev_type,
1117 const void* nd_alloc);
1119 #define MXLIB_INITIALIZE_STR "initialize" 1122 #define MXLIB_OPVERSION_STR "_opVersion" 1125 #define MXLIB_MSGSIZE_STR "_msgSize" 1128 #define MXLIB_MSGGET_STR "_msgGet" 1136 : instance(inst), destroy_(destroy) {}
1143 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) 1144 #define MX_INT_RET __declspec(dllexport) int __cdecl 1145 #define MX_VOID_RET __declspec(dllexport) void __cdecl 1147 #define MX_INT_RET int 1148 #define MX_VOID_RET void 1164 int* forward_count,
const char*** backward_ctx,
1176 const char*
const* vals,
int num,
1177 int* num_in,
int* num_out);
1181 const char*
const* vals,
int num,
1182 unsigned int** inshapes,
int* indims,
int num_in,
1183 unsigned int*** mod_inshapes,
int** mod_indims,
1184 unsigned int*** outshapes,
int** outdims,
int num_out);
1188 const char*
const* vals,
int num,
1189 int* intypes,
int num_in,
int* outtypes,
int num_out);
1193 const char*
const* vals,
int num,
1194 int* instypes,
int num_in,
int* outstypes,
int num_out);
1198 const char*
const* vals,
1199 int num,
const int64_t** inshapes,
int* indims,
void** indata,
1200 int* intypes,
size_t* inIDs,
const char** indev_type,
int* indev_id,
1201 int num_in,
const int64_t** outshapes,
int* outdims,
void** outdata,
1202 int* outtypes,
size_t* outIDs,
const char** outdev_type,
1208 int* instypes,
int* outstypes,
void** in_indices,
void** out_indices,
1209 void** in_indptr,
void** out_indptr,
1210 int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1211 int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1212 void* rng_cpu_states,
void* rng_gpu_states);
1216 const char*
const* vals,
int num,
1217 int** mutate_indices,
int* indices_size);
1221 const char*
const* vals,
int num,
const char* dev_type,
1222 int dev_id,
unsigned int** inshapes,
int* indims,
1223 int num_in,
const int* intypes,
void** state_op);
1230 int* indims,
void** indata,
int* intypes,
size_t* inIDs,
1231 const char** indev_type,
int* indev_id,
int num_in,
1232 const int64_t** outshapes,
int* outdims,
void** outdata,
1233 int* outtypes,
size_t* outIDs,
const char** outdev_type,
1234 int* outdev_id,
int num_out,
1239 void* sparse_alloc,
int* instypes,
int* outstypes,
1240 void** in_indices,
void** out_indices,
void** in_indptr,
1241 void** out_indptr, int64_t* in_indices_shapes,
1242 int64_t* out_indices_shapes, int64_t* in_indptr_shapes,
1243 int64_t* out_indptr_shapes,
1244 void* rng_cpu_states,
void* rng_gpu_states);
1261 int num_ids,
int *ids,
const char*
const* opt_keys,
1262 const char*
const* opt_vals,
int num_opts);
1266 void** selector,
const char*
const* opt_keys,
1267 const char*
const* opt_vals,
int num_opts);
1274 int input_nodeID,
int* selected);
1278 int output_nodeID,
int* selected);
1282 int** keep,
int* num_keep);
1289 int subgraph_id,
int *accept,
const char*
const* opt_keys,
1290 const char*
const* opt_vals,
int num_opts,
1291 char*** attr_keys,
char*** attr_vals,
int *num_attrs,
1292 const char*
const* arg_names,
int num_args,
1293 void*
const* arg_data,
const int64_t*
const* arg_shapes,
1294 const int* arg_dims,
const int* arg_types,
1295 const size_t* arg_IDs,
const char*
const* arg_dev_type,
1296 const int* arg_dev_id,
1297 const char*
const* aux_names,
int num_aux,
1298 void*
const* aux_data,
const int64_t*
const* aux_shapes,
1299 const int* aux_dims,
const int* aux_types,
1300 const size_t* aux_IDs,
const char*
const* aux_dev_type,
1301 const int* aux_dev_id);
1308 const char** pass_name);
1312 char** out_graph,
const char*
const* opt_keys,
1313 const char*
const* opt_vals,
int num_opts,
1314 const char* pass_name,
const char*
const* arg_names,
int num_args,
1315 void*
const* arg_data,
const int64_t*
const* arg_shapes,
1316 const int* arg_dims,
const int* arg_types,
1317 const size_t* arg_IDs,
const char*
const* arg_dev_type,
1318 const int* arg_dev_id,
const char*
const* aux_names,
int num_aux,
1319 void*
const* aux_data,
const int64_t*
const* aux_shapes,
1320 const int* aux_dims,
const int* aux_types,
1321 const size_t* aux_IDs,
const char*
const* aux_dev_type,
1323 const void* nd_alloc);
1332 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) 1345 #endif // MXNET_LIB_API_H_ An abstract class for subgraph property.
Definition: lib_api.h:836
int(* opVersion_t)()
Definition: lib_api.h:1123
int(* partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1054
const char * name
partitioner name
Definition: lib_api.h:858
MXReturnValue(* inferShape_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< std::vector< unsigned int > > *in_shapes, std::vector< std::vector< unsigned int > > *out_shapes)
Definition: lib_api.h:733
DLDeviceType
The device type in DLContext.
Definition: lib_api.h:101
std::string getShapeAt(const std::string &shape, unsigned index)
int(* opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
Definition: lib_api.h:966
MX_VOID_RET _opRegGet(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
returns (through its output parameters) the operator registration at the specified index
virtual MXReturnValue Backward(std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &op_res)
Definition: lib_api.h:703
int(* msgGet_t)(int idx, const char **msg)
Definition: lib_api.h:1129
const char * name
operator name
Definition: lib_api.h:775
Definition: lib_api.h:257
MXTensor * tensor
Definition: lib_api.h:571
virtual void Filter(const std::vector< int > &candidates, std::vector< int > *keep)
Definition: lib_api.h:671
CustomStatefulOpWrapper(CustomStatefulOp *inst, opCallDestroyOpState_t destroy)
Definition: lib_api.h:1135
Provides resource APIs (such as memory allocation) to Forward/Backward functions.
Definition: lib_api.h:417
Definition: lib_api.h:255
void(* partCallReset_t)(void *sel_inst)
Definition: lib_api.h:1079
std::vector< const char * > strategies
strategy names
Definition: lib_api.h:863
OpenCL devices.
Definition: lib_api.h:112
std::vector< NodeEntry > outputs
Definition: lib_api.h:573
DLDataTypeCode
The type code options DLDataType.
Definition: lib_api.h:142
MX_VOID_RET _partCallSelectOutput(void *sel_inst, int nodeID, int output_nodeID, int *selected)
calls the select output function from the library; the result is written to *selected (no return value)
StatefulOp wrapper class to pass to backend OpState.
Definition: lib_api.h:1132
MX_INT_RET _opRegSize()
returns number of ops registered in this library
T & add(const char *name)
add a new entry
Definition: lib_api.h:887
Metal for Apple GPU.
Definition: lib_api.h:116
Definition: lib_api.h:296
int(* passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, nd_malloc_t nd_malloc, const void *nd_alloc)
Definition: lib_api.h:1104
namespace of mxnet
Definition: api_registry.h:33
Definition: lib_api.h:145
void(* partCallSelectInput_t)(void *sel_inst, int nodeID, int input_nodeID, int *selected)
Definition: lib_api.h:1067
#define MX_VOID_RET
Definition: lib_api.h:1148
ROCm GPUs for AMD GPUs.
Definition: lib_api.h:120
A Device context for Tensor and operator.
Definition: dlpack.h:69
std::vector< Graph * > subgraphs
Definition: lib_api.h:574
CUDA GPU device.
Definition: lib_api.h:105
Node * node
Definition: lib_api.h:549
bool ignore_warn
Definition: lib_api.h:710
void(* partRegGet_t)(int part_idx, int stg_idx, const char **strategy, supportedOps_t *supportedOps, createSelector_t *createSelector, reviewSubgraph_t *reviewSubgraph, const char **op_name)
Definition: lib_api.h:1049
MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
returns status of calling mutateInputs function for operator from library
int64_t * indices
Definition: lib_api.h:305
int64_t data_len
Definition: lib_api.h:300
MXStorageType
Definition: lib_api.h:263
Reserved extension device type, used for quickly test extension device The semantics can differ depen...
Definition: lib_api.h:126
Pinned CUDA GPU device by cudaMallocHost.
Definition: lib_api.h:110
inferSType_t infer_storage_type
Definition: lib_api.h:780
MXContext ctx
Definition: lib_api.h:359
#define PRIVATE_SYMBOL
For loading multiple custom op libraries in Linux, exporting same symbol multiple times may lead to u...
Definition: lib_api.h:65
MX_INT_RET _partRegSize()
returns number of partitioners registered in this library
Definition: lib_api.h:495
MXReturnValue(* graphPass_t)(mxnet::ext::Graph *graph, const std::unordered_map< std::string, std::string > &options)
Custom Pass Create function template.
Definition: lib_api.h:799
MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling create selector function from library
MX_VOID_RET _partCallSelect(void *sel_inst, int nodeID, int *selected)
calls the select function from the library; the result is written to *selected (no return value)
MX_VOID_RET _opCallDestroyOpState(void *state_op)
deletes the StatefulOp instance for the operator from the library (no return value)
Definition: lib_api.h:254
int num
Definition: lib_api.h:535
Tensor data structure used by custom operator.
Definition: lib_api.h:320
MX_VOID_RET _partCallFilter(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
calls the filter function from the library; the kept candidates are returned through keep and num_keep (no return value)
std::unordered_map< std::string, std::string > attrs
Definition: lib_api.h:575
void *(* xpu_malloc_t)(void *, int)
resource malloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:370
int(* opCallInferShape_t)(inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
Definition: lib_api.h:971
MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *instypes, int num_in, int *outstypes, int num_out)
returns status of calling inferSType function for operator from library
MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling supported ops function from library
MX_VOID_RET _partCallSelectInput(void *sel_inst, int nodeID, int input_nodeID, int *selected)
calls the select input function from the library; the result is written to *selected (no return value)
MXReturnValue(* createOpState_t)(const std::unordered_map< std::string, std::string > &attributes, const MXContext &ctx, const std::vector< std::vector< unsigned int > > &in_shapes, const std::vector< int > in_types, CustomStatefulOp **)
Definition: lib_api.h:740
MXReturnValue
Definition: lib_api.h:290
void(* partCallSelectOutput_t)(void *sel_inst, int nodeID, int output_nodeID, int *selected)
Definition: lib_api.h:1071
std::string name
Definition: lib_api.h:570
MX_INT_RET _opCallFStatefulCompute(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Stateful Forward/Backward for operator from library
CustomStatefulOp * get_instance()
Definition: lib_api.h:1137
std::map< std::string, createSelector_t > selector_map
Definition: lib_api.h:860
std::vector< NodeEntry > inputs
Definition: lib_api.h:572
Definition: lib_api.h:253
graphPass_t pass
pass function
Definition: lib_api.h:815
Definition: lib_api.h:256
std::string str
Definition: lib_api.h:536
MXReturnValue(* createSelector_t)(const mxnet::ext::Graph *graph, CustomOpSelector **sel_inst, const std::unordered_map< std::string, std::string > &options)
Definition: lib_api.h:822
MXDType dtype
Definition: lib_api.h:353
int(* opCallMutateInputs_t)(mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
Definition: lib_api.h:1009
Definition: lib_api.h:143
inferType_t infer_type
Definition: lib_api.h:779
std::map< std::string, JsonVal > attrs
Definition: lib_api.h:638
void(* partCallSelect_t)(void *sel_inst, int nodeID, int *selected)
Definition: lib_api.h:1064
int entry
Definition: lib_api.h:550
Definition: lib_api.h:250
Definition: lib_api.h:220
Class to hold custom operator registration.
Definition: lib_api.h:750
MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
returns status of calling inferType function for operator from library
Definition: lib_api.h:582
int(* opRegGet_t)(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
Definition: lib_api.h:953
Vulkan buffer for next generation graphics.
Definition: lib_api.h:114
MX_INT_RET _opVersion()
returns MXNet library version
inferShape_t infer_shape
Definition: lib_api.h:781
#define MX_ERROR_MSG
Definition: lib_api.h:244
int64_t indices_len
Definition: lib_api.h:306
MXReturnValue(* inferSType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_storage_types, std::vector< int > *out_storage_types)
Definition: lib_api.h:729
std::vector< const char * > op_names
subgraph operator name
Definition: lib_api.h:865
MXReturnValue(* reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, bool *accept, const std::unordered_map< std::string, std::string > &options, std::unordered_map< std::string, std::string > *attrs)
Definition: lib_api.h:826
DLDeviceType
The device type in DLContext.
Definition: dlpack.h:38
MX_INT_RET _partRegGetCount(int idx, const char **name)
void * mx_gpu_rand_t
Definition: lib_api.h:383
Verilog simulator buffer.
Definition: lib_api.h:118
mx_stream_t get_cuda_stream() const
returns the CUDA stream object cast to the correct type (mx_stream_t)
Definition: lib_api.h:431
std::vector< createOpState_t > create_op_fp
Definition: lib_api.h:788
MXReturnValue(* fcomp_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &res)
Custom Operator function templates.
Definition: lib_api.h:717
mxnet::ext::MXReturnValue initialize(int version)
Checks if the MXNet version is supported by the library. If supported, initializes the library...
JsonType
Json utility to parse serialized subgraph symbol.
Definition: lib_api.h:495
MX_VOID_RET _opCallFree(void *ptr)
calls free from the external library on arrays that the library allocated
MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
returns status of calling review subgraph function from library
int(* opCallInferSType_t)(inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:983
MX_INT_RET _passRegSize()
returns number of graph passes registered in this library
Definition: lib_api.h:144
int(* passRegSize_t)(void)
Definition: lib_api.h:1098
definition of JSON objects
Definition: lib_api.h:498
Definition: lib_api.h:548
MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, const void *nd_alloc)
returns status of calling graph pass function from library
int(* partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
Definition: lib_api.h:1082
std::string getDtypeAt(const std::string &dtype, unsigned index)
bool isSGop
Definition: lib_api.h:783
MXReturnValue(* inferType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_types, std::vector< int > *out_types)
Definition: lib_api.h:725
int(* opCallFree_t)(void *ptr)
Definition: lib_api.h:963
int(* opCallFStatefulComp_t)(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:1023
size_t verID
Definition: lib_api.h:356
std::vector< JsonVal > list
Definition: lib_api.h:537
MXReturnValue(* supportedOps_t)(const mxnet::ext::Graph *graph, std::vector< int > *ids, const std::unordered_map< std::string, std::string > &options)
Custom Subgraph Create function template.
Definition: lib_api.h:819
#define MX_INT_RET
Definition: lib_api.h:1147
const char * name
pass name
Definition: lib_api.h:813
parseAttrs_t parse_attrs
operator functions
Definition: lib_api.h:778
Definition: lib_api.h:393
data_type * data()
helper function to cast data pointer
Definition: lib_api.h:335
MXReturnValue(* parseAttrs_t)(const std::unordered_map< std::string, std::string > &attributes, int *num_inputs, int *num_outputs)
Definition: lib_api.h:722
mutateInputs_t mutate_inputs
Definition: lib_api.h:782
bool wasCreated()
Definition: lib_api.h:698
Context info passing from MXNet OpContext dev_type is string repr of supported context, currently only "cpu" and "gpu" dev_id is the device index where the tensor locates.
Definition: lib_api.h:277
std::mt19937 mx_cpu_rand_t
Definition: lib_api.h:385
MX_VOID_RET _msgGet(int idx, const char **msg)
returns operator registration at specified index
MX_VOID_RET _partCallReset(void *sel_inst)
returns status of calling reset selector function from library
Definition: lib_api.h:291
static CustomStatefulOp * create(Ts...args)
Definition: lib_api.h:692
void * mx_stream_t
GPU stream pointer, is void* when not compiled with CUDA.
Definition: lib_api.h:382
An abstract class for library authors creating stateful op custom library should override Forward and...
Definition: lib_api.h:686
int(* partCallCreateSelector_t)(createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1059
std::map< std::string, supportedOps_t > supported_map
Definition: lib_api.h:859
int size()
Definition: lib_api.h:892
Definition: lib_api.h:292
std::vector< const char * > forward_ctx_cstr
vector repr of ctx map to be easily loaded from c_api
Definition: lib_api.h:786
CPU device.
Definition: lib_api.h:103
int(* opCallDestroyOpState_t)(void *state_op)
Definition: lib_api.h:1020
int(* partRegSize_t)(void)
Definition: lib_api.h:1043
Definition: lib_api.h:554
void(* passRegGet_t)(int pass_idx, graphPass_t *graphPass, const char **pass_name)
Definition: lib_api.h:1101
Definition: lib_api.h:495
void(* sparse_malloc_t)(void *, int, int, int, void **, int64_t **, int64_t **)
sparse alloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:372
MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t *graphPass, const char **pass_name)
returns pass registration at specified index
void(* partCallFilter_t)(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
Definition: lib_api.h:1075
std::vector< NodeEntry > outputs
Definition: lib_api.h:637
int64_t indptr_len
Definition: lib_api.h:311
int dev_id
Definition: lib_api.h:287
std::vector< Node * > inputs
Definition: lib_api.h:636
DLTensor dltensor
Definition: lib_api.h:363
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
returns status of calling createStatefulOp function for operator from library
int(* opCallInferType_t)(inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:978
std::string dev_type
Definition: lib_api.h:286
int(* opCallCreateOpState_t)(createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
Definition: lib_api.h:1014
MXDType
Tensor data type, consistent with mshadow data type.
Definition: lib_api.h:249
int(* partRegGetCount_t)(int idx, const char **name)
Definition: lib_api.h:1046
std::vector< fcomp_t > forward_fp
Definition: lib_api.h:787
Definition: lib_api.h:495
MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
returns status of calling inferShape function for operator from library
int(* initialize_t)(int version)
Definition: lib_api.h:1120
std::map< JsonVal, JsonVal > map
Definition: lib_api.h:538
Definition: lib_api.h:269
int(* opRegSize_t)(void)
Definition: lib_api.h:950
An abstract class for graph passes.
Definition: lib_api.h:805
std::vector< int64_t > shape
Definition: lib_api.h:350
Registry class to registers things (ops, properties) Singleton class.
Definition: lib_api.h:873
JsonType type
Definition: lib_api.h:534
Definition: lib_api.h:267
The data type the tensor can hold.
Definition: dlpack.h:94
void(* nd_malloc_t)(const void *_ndarray_alloc, const int64_t *shapes, int num_shapes, const char *dev_str, int dev_id, int dtype, const char *name, int isArg, void **data)
resource malloc function to allocate ndarrays for graph passes
Definition: lib_api.h:374
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Forward/Backward function for operator from library
Definition: lib_api.h:648
std::stringstream & add(const char *file, int line)
virtual void Reset()
Definition: lib_api.h:678
Definition: lib_api.h:265
std::string op
Definition: lib_api.h:569
Definition: lib_api.h:251
Definition: lib_api.h:252
Definition: lib_api.h:495
Definition: lib_api.h:495
MXReturnValue(* mutateInputs_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *input_indices)
Definition: lib_api.h:737
int(* msgSize_t)(void)
Definition: lib_api.h:1126
mx_gpu_rand_t * get_gpu_rand_states() const
get pointer to initialized and seeded random number states located on GPU
Definition: lib_api.h:445
std::map< std::string, reviewSubgraph_t > review_map
Definition: lib_api.h:861
MXStorageType stype
Definition: lib_api.h:366
int(* opCallFComp_t)(fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:988
MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char **strategy, mxnet::ext::supportedOps_t *supportedOps, mxnet::ext::createSelector_t *createSelector, mxnet::ext::reviewSubgraph_t *reviewSubgraph, const char **op_name)
returns partitioner registration at specified index
MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
returns status of calling parse attributes function for operator from library
void * data_ptr
Definition: lib_api.h:347