mxnet
lib_api.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
32 #ifndef MXNET_LIB_API_H_
33 #define MXNET_LIB_API_H_
34 
35 #include <stdint.h>
36 #include <stdlib.h>
37 #include <string.h>
38 #include <vector>
39 #include <map>
40 #include <unordered_set>
41 #include <unordered_map>
42 #include <string>
43 #include <iostream>
44 #include <utility>
45 #include <stdexcept>
46 #include <functional>
47 #include <random>
48 #include <sstream>
49 
50 #if defined(__NVCC__)
51  #include <cuda_runtime.h>
52  #include <curand_kernel.h>
53 #endif
54 
55 /* Make sure to update the version number every time you make changes */
56 #define MX_LIBRARY_VERSION 10
57 
63 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
64  #define PRIVATE_SYMBOL
65 #else
66  #define PRIVATE_SYMBOL __attribute__ ((visibility ("hidden")))
67 #endif
68 
69 /*
70  * Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
71  */
72 #ifndef DLPACK_VERSION
73 #ifdef __cplusplus
74 #define DLPACK_EXTERN_C extern "C"
75 #else
76 #define DLPACK_EXTERN_C
77 #endif
78 
80 #define DLPACK_VERSION 020
81 
83 #ifdef _WIN32
84 #ifdef DLPACK_EXPORTS
85 #define DLPACK_DLL __declspec(dllexport)
86 #else
87 #define DLPACK_DLL __declspec(dllimport)
88 #endif
89 #else
90 #define DLPACK_DLL
91 #endif
92 
93 #include <stdint.h>
94 #include <stddef.h>
95 
96 #ifdef __cplusplus
97 extern "C" {
98  #endif
99 
102  typedef enum {
104  kDLCPU = 1,
106  kDLGPU = 2,
117  kDLMetal = 8,
119  kDLVPI = 9,
121  kDLROCM = 10,
127  kDLExtDev = 12,
128  } DLDeviceType;
129 
133  typedef struct {
135  DLDeviceType device_type;
137  int device_id;
138  } DLContext;
139 
143  typedef enum {
144  kDLInt = 0U,
145  kDLUInt = 1U,
146  kDLFloat = 2U,
147  } DLDataTypeCode;
148 
157  typedef struct {
163  uint8_t code;
167  uint8_t bits;
169  uint16_t lanes;
170  } DLDataType;
171 
175  typedef struct {
195  void* data;
197  DLContext ctx;
199  int ndim;
201  DLDataType dtype;
203  int64_t* shape;
208  int64_t* strides;
210  uint64_t byte_offset;
211  } DLTensor;
212 #ifdef __cplusplus
213 } // DLPACK_EXTERN_C
214 #endif
215 #endif
216 
217 namespace mxnet {
218 namespace ext {
219 
220 /* \brief Class to store error messages from extensions to pass to MXNet */
/* Process-wide collector for error text raised inside an extension library.
 * The MX_ERROR_MSG macro calls get()->add(__FILE__, __LINE__) and streams the
 * message text into the returned stringstream. */
221 class MXerrorMsgs {
222  public:
223  /* \brief get singleton pointer to class */
224  static MXerrorMsgs* get();
225
226  /* \brief add a new error message */
  /* Returns the stream the caller writes the message into; file/line identify
   * the call site. */
227  std::stringstream& add(const char* file, int line);
228
229  /* \brief return number of error messages */
230  int size();
231
232  /* \brief get error message at index */
  /* NOTE(review): behavior for an out-of-range idx is defined outside this
   * header — confirm before relying on it. */
233  const std::string* get(int idx);
234
235  private:
  /* Constructor and destructor are private so that only get() can create the
   * instance. NOTE(review): the copy constructor is still implicitly declared;
   * consider deleting it explicitly. */
237  MXerrorMsgs() {}
239  ~MXerrorMsgs();
  /* Raw pointers, presumably allocated by add() and released by ~MXerrorMsgs()
   * (both defined out of view) — confirm. */
241  std::vector<std::stringstream*> messages;
242 };
243 
244 // Add a new error message, example: MX_ERROR_MSG << "my error msg";
245 #define MX_ERROR_MSG mxnet::ext::MXerrorMsgs::get()->add(__FILE__, __LINE__)
246 
/*
 * Element type tag for MXTensor data. The enumerators carry explicit integer
 * values, and the C-ABI prototypes in this header pass dtypes as plain `int`
 * (e.g. `int dtype` in nd_malloc_t, `int* intypes` in the call prototypes).
 * These numbers are therefore what crosses the library boundary.
 * NOTE(review): they presumably must match MXNet's own type flags — confirm
 * before renumbering or inserting values.
 */
250 enum MXDType {
251  kFloat32 = 0,
252  kFloat64 = 1,
253  kFloat16 = 2,
254  kUint8 = 3,
255  kInt32 = 4,
256  kInt8 = 5,
257  kInt64 = 6,
  // Out-of-band marker, presumably "dtype not yet known" (inferred from the name).
258  kUNSET = 100,
259 };
260 
261 /*
262  * MXTensor storage type.
263  */
265  // dense
267  // row sparse
269  // csr
271 };
272 
/*
 * Device on which a tensor's data lives: a device-type string plus a device
 * index. The constructor and factory definitions are not in this header.
 * NOTE(review): the type strings are presumably "cpu"/"gpu" (they map onto the
 * `const char* dev_str` / `dev_type` strings in the C-ABI prototypes), and the
 * zero-argument CPU()/GPU() presumably select device 0 — confirm in the .cc.
 */
278 struct MXContext {
279  MXContext();
280  explicit MXContext(std::string dev_type_, int dev_id_);
281  explicit MXContext(const char* dev_type_, int dev_id_);
282  static MXContext CPU();
283  static MXContext GPU();
284  static MXContext CPU(int dev_id);
285  static MXContext GPU(int dev_id);
286
  // Device type name and device index.
287  std::string dev_type;
288  int dev_id;
289 };
290 
292  MX_FAIL = 0,
294 };
295 
296 // For sparse tensors, read/write the data from NDArray via pointers.
struct MXSparse {
  // Every field has a default member initializer, so a default-constructed
  // MXSparse holds null pointers and zero lengths rather than indeterminate
  // values. Member order and types are unchanged, so the object layout is too.
  // MXSparse declares no destructor: the pointers are non-owning.

  // Pointer to data.
  void* data{nullptr};
  // Length of (non-zero) data.
  int64_t data_len{0};

  // To store aux data for sparse.
  // For CSR, indices stores the col index of non-zero elements.
  // For row sparse, indices stores the row index of rows which have non-zero elements.
  int64_t* indices{nullptr};
  int64_t indices_len{0};

  // For CSR, indptr gives the start and end index of data for each row.
  // For row sparse, indptr is not used.
  int64_t* indptr{nullptr};
  int64_t indptr_len{0};

  // Fill the fields from the given buffers. Defined outside this header.
  // NOTE(review): how dims/ndims map onto data_len is not visible here — confirm.
  void set(void* data_ptr, const int64_t* dims, int ndims, void* idx,
           int64_t num_idx, void* idx_ptr = nullptr, int64_t num_idx_ptr = 0);
};
317 
/*
 * Tensor handle exchanged with custom-operator callbacks. data_ptr is a raw
 * pointer and no destructor is visible, so the tensor presumably does not own
 * its buffer.
 * NOTE(review): the comments for the dtype, context, DLTensor and storage-type
 * members below are not followed by any declaration in this listing. The
 * constructors take MXDType/MXContext/MXStorageType and setDLTensor() exists,
 * so those member lines appear to have been dropped — confirm against the
 * real lib_api.h.
 */
321 struct MXTensor {
322  MXTensor();
323  MXTensor(const MXTensor& oth);
324  MXTensor(void *data_ptr, std::vector<int64_t> shape, MXDType dtype,
325  size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage);
326
  // Reinitialize all fields; vID is presumably stored as verID (defined out of view).
328  void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
329  size_t vID, MXContext mx_ctx, MXStorageType storage_type);
330
  // Presumably rebuilds the DLTensor view from the other fields (defined out of view).
332  void setDLTensor();
333
  // Reinterpret data_ptr as data_type*; no check against the tensor's dtype is made here.
335  template<typename data_type>
336  inline data_type* data() {
337  return reinterpret_cast<data_type*>(data_ptr);
338  }
339
  // Presumably the element count (product of shape) — definition not visible.
341  int64_t size() const;
342
  // Comparison semantics are defined outside this header.
344  bool isSame(const MXTensor &oth) const;
345
346  // For dense, data_ptr points to 1D flattened tensor data
347  // For sparse, data_ptr points to MXSparse
348  void *data_ptr;
349
350  // shape is in [2,3,4] format to represent high-dim tensor
351  std::vector<int64_t> shape;
352
353  // type can only be MXDType enum types
355
356  // version number updated if the tensor has changed since the last use by custom op
357  size_t verID;
358
359  // context of MXTensor representing which device the tensor data is located
361
362  // corresponding DLTensor repr of MXTensor
363  // easy way to reuse functions taking DLTensor
365
366  // storage type
368 };
369 
/* Allocator callback type; OpResource stores one for CPU (cpu_malloc) and one
 * for GPU (gpu_malloc). NOTE(review): the (void*, int) pair is presumably
 * (opaque allocator handle, size in bytes) — the parameters are unnamed here. */
371 typedef void* (*xpu_malloc_t)(void*, int);
/* Sparse allocator callback type stored by OpResource as sparse_malloc; the
 * meaning of each parameter is not spelled out in this header. */
373 typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
/* NDArray allocator type passed to PassResource for creating new arg/aux
 * params. isArg presumably distinguishes arg from aux, and *data presumably
 * receives the new buffer — confirm against the implementation. */
375 typedef void (*nd_malloc_t)(const void* _ndarray_alloc, const int64_t* shapes, int num_shapes,
376  const char* dev_str, int dev_id, int dtype, const char* name,
377  int isArg, void** data);
/* Under NVCC the stream and GPU RNG state use the real CUDA types; otherwise
 * both become void*, so the header compiles without CUDA headers. */
379 #if defined(__NVCC__)
380  typedef cudaStream_t mx_stream_t;
381  typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
382 #else
383  typedef void* mx_stream_t;
384  typedef void* mx_gpu_rand_t;
385 #endif
/* Engine type of the CPU random-number states. */
386 typedef std::mt19937 mx_cpu_rand_t;
387
389 /* Number of random states available per device. Each thread should use its own state so that threads produce distinct random sequences. */
390 #define MX_NUM_CPU_RANDOM_STATES 1024
391 #define MX_NUM_GPU_RANDOM_STATES 32768
392 
393 /* \brief Class to help allocate new args/aux params in graph passes */
395  public:
396  PassResource(std::unordered_map<std::string, MXTensor>* new_args,
397  std::unordered_map<std::string, MXTensor>* new_aux,
398  nd_malloc_t nd_malloc, const void* nd_alloc);
399 
400  // allocate new arg param, adds to args map, returns newly allocated tensor
401  MXTensor* alloc_arg(const std::string& name, const std::vector<int64_t>& shapes,
402  const MXContext &ctx, MXDType dtype) const;
403 
404  // allocate new aux param, adds to aux map, returns newly allocated tensor
405  MXTensor* alloc_aux(const std::string& name, const std::vector<int64_t>& shapes,
406  const MXContext &ctx, MXDType dtype) const;
407 
408  private:
409  std::unordered_map<std::string, MXTensor>* new_args_;
410  std::unordered_map<std::string, MXTensor>* new_aux_;
411  nd_malloc_t nd_malloc_;
412  const void* nd_alloc_;
413 };
414 
/*
 * Framework-provided resources handed to custom-op compute callbacks: memory
 * allocators, the CUDA stream and per-device RNG states. The members are
 * function pointers and opaque pointers that presumably come straight from the
 * constructor arguments (constructor defined outside this header).
 */
418 class OpResource {
419  public:
420  OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
421  xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
422  sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp,
423  void* rng_cpu_states, void* rng_gpu_states);
424
  /* Allocate CPU memory, presumably `size` bytes via cpu_malloc/cpu_alloc.
   * NOTE(review): the `int` parameter caps a single request at INT_MAX. */
426  void* alloc_cpu(int size) const;
427
  /* GPU counterpart of alloc_cpu, presumably via gpu_malloc/gpu_alloc. */
429  void* alloc_gpu(int size) const;
430
  /* Returns the stream pointer cast to mx_stream_t (cudaStream_t under NVCC,
   * void* otherwise). */
432  inline mx_stream_t get_cuda_stream() const {
433  return static_cast<mx_stream_t>(cuda_stream);
434  }
435
  /* Allocate storage for a sparse tensor, presumably via sparse_malloc. */
437  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) const;
438
440  /* Access each state by states[id]; valid ids are 0 .. MX_NUM_CPU_RANDOM_STATES - 1 (id < MX_NUM_CPU_RANDOM_STATES) */
441  mx_cpu_rand_t* get_cpu_rand_states() const;
442
444  /* Access each state by states[id]; valid ids are 0 .. MX_NUM_GPU_RANDOM_STATES - 1 (id < MX_NUM_GPU_RANDOM_STATES) */
445  /* Returns rand_gpu_states cast to mx_gpu_rand_t*; in a CPU build this is expected to be a nullptr */
446  inline mx_gpu_rand_t* get_gpu_rand_states() const {
447  return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
448  }
449
450  private:
  /* Allocator callbacks and the opaque handles that presumably accompany them. */
452  xpu_malloc_t cpu_malloc, gpu_malloc;
454  void *cpu_alloc, *gpu_alloc;
456  void *cuda_stream;
458  sparse_malloc_t sparse_malloc;
460  void *sparse_alloc;
  /* Opaque pointers to the RNG state arrays returned by the getters above. */
462  void *rand_cpu_states, *rand_gpu_states;
463 };
464 
/* Reserved attribute-key strings. NOTE(review): presumably used as keys in the
 * node/graph attribute maps exchanged with MXNet — their use is not visible
 * in this header. */
466 #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
467
468 #define MX_STR_DTYPE "__ext_dtype__"
469
470 #define MX_STR_SHAPE "__ext_shape__"
471
472 #define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__"
473
474 /* \brief get shape value from list of shapes string
475  *
476  * Examples:
477  *
478  * getShapeAt("[[1]]", 0) returns "[1]"
479  * getShapeAt("[[1],[2,3]]", 1) returns "[2,3]"
480  */
// NOTE(review): the result for an out-of-range index is not documented here.
481 std::string getShapeAt(const std::string& shape, unsigned index);
482
483 /* \brief get dtype value from list of dtypes string
484  *
485  * Examples:
486  *
487  * getDtypeAt("[1]", 0) returns "1"
488  * getDtypeAt("[1,2]", 1) returns "2"
489  */
// NOTE(review): the result for an out-of-range index is not documented here.
490 std::string getDtypeAt(const std::string& dtype, unsigned index);
491
/* Kind of value held by a JsonVal (string, number, list, map). ERR presumably
 * marks an invalid or unset value — inferred from the name. */
496 enum JsonType {ERR, STR, NUM, LIST, MAP};
497 
499 struct JsonVal {
500  JsonVal(); // default constructor
501  // construct a JSON object by type
502  explicit JsonVal(JsonType t);
503  // construct a string JSON object
504  explicit JsonVal(std::string s);
505  // construct a number JSON object
506  explicit JsonVal(int n);
507  // complex constructor
508  JsonVal(JsonType t, int n, std::string s);
509  bool operator<(const JsonVal &o) const;
510 
511  // convert JSON object back to JSON-compatible string
512  std::string dump() const;
513 
514  // convert JSON-compatible string to JSON object
515  static JsonVal parse(const std::string& json);
516 
517  // parse a string JSON object
518  static JsonVal parse_string(const std::string& json, unsigned int* idx);
519 
520  // parse a number JSON object
521  static JsonVal parse_num(const std::string& json, unsigned int* idx);
522 
523  // parse a list of JSON objects
524  static JsonVal parse_list(const std::string& json, unsigned int* idx);
525 
526  // parse a map of JSON objects
527  static JsonVal parse_map(const std::string& json, unsigned int* idx);
528 
529  // generic parse function
530  static JsonVal parse(const std::string& json, unsigned int *idx);
531 
532  // debug function to convert data structure to a debugstring
533  std::string toString() const;
534 
536  int num;
537  std::string str;
538  std::vector<JsonVal> list;
539  std::map<JsonVal, JsonVal> map;
540 };
541 
545 class Node;
546 class Graph;
547
548 // Representation of an input/output to a node
// Plain aggregate with no default member initializers: set both fields explicitly.
549 struct NodeEntry {
550  Node* node; // the other node that is producing/consuming the inputs/outputs
551  int entry; // entry index on the other node (i.e. the output index of the producing node)
552 };
553 
554 // Representation of a node in the graph
/* One node of a Graph (an operator or a graph input). All data members are
 * public except the PassResource pointer used for tensor allocation. */
555 class Node {
556  public:
557  Node();
558
559  // internally set passResource to enable tensor allocation for graph passes
560  void _setPassResource(PassResource* res_);
561
562  /* \brief allocate an arg tensor for this node */
  /* NOTE(review): unlike PassResource::alloc_arg there is no name parameter,
   * so the node's own name is presumably used. This presumably requires
   * _setPassResource() to have been called first. */
563  void alloc_arg(const std::vector<int64_t>& shapes,
564  const MXContext &ctx, MXDType dtype);
565
566  /* \brief allocate an aux tensor for this node */
  /* Same assumptions as alloc_arg, for aux params. */
567  void alloc_aux(const std::vector<int64_t>& shapes,
568  const MXContext &ctx, MXDType dtype);
569
570  std::string op; // operator name (ie. Convolution)
571  std::string name; // unique node name (ie. conv_0 or conv_1)
572  MXTensor* tensor; // tensor data for input nodes
573  std::vector<NodeEntry> inputs; // set of inputs to the node
574  std::vector<NodeEntry> outputs; // set of outputs from the node
575  std::vector<Graph*> subgraphs; // set of subgraphs within this node
576  std::unordered_map<std::string, std::string> attrs; // node attributes
577
578  private:
579  PassResource* res;
580 };
581 
582 // Representation of the graph
/* In-memory graph used by partitioners and graph passes. It owns its Node
 * objects (see ~Graph) and can be converted to and from the JSON form.
 * NOTE(review): the destructor deletes the nodes, but the copy constructor and
 * copy assignment are not deleted. Copying a Graph would copy the raw Node*
 * vector and presumably lead to a double delete; consider deleting copy
 * operations. */
583 class Graph {
584  public:
585  Graph();
586
587  /* \brief deletes the nodes owned by this graph when the graph is destroyed */
588  ~Graph();
589
590  /* \brief create a graph object from an unparsed string */
  /* Returns a raw pointer; the caller presumably owns it — confirm. */
591  static Graph* fromString(const std::string& json);
592
593  /* \brief create a graph object from a parsed JSON object */
594  static Graph* fromJson(JsonVal val);
595
596  /* \brief convert graph object back to JSON object */
597  JsonVal toJson() const;
598
599  /* \brief convert graph object to JSON string */
600  std::string toString() const;
601
602  /* \brief visits a node "n" */
  /* Recursive helper behind DFS(); to_visit presumably tracks nodes not yet
   * visited and handler is invoked per node. */
603  void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit,
604  std::function<void(Node*)> handler) const;
605
606  /* \brief post-order DFS graph traversal */
607  void DFS(std::function<void(Node*)> handler) const;
608
609  /* \brief sort graph nodes in topological order */
610  std::vector<Node*> topological_sort() const;
611
612  /* \brief print out graph details */
613  void print(int indent = 0) const;
614
615  /* \brief add a new node to this graph */
616  Node* addNode(const std::string& name, const std::string& op);
617
618  /* \brief get node at index in graph */
619  Node* getNode(size_t idx);
620
621  /* \brief get const node at index in const graph */
622  const Node* getNode(size_t idx) const;
623
624  /* \brief get attribute on graph */
  /* NOTE(review): behavior for a missing key is defined out of view. */
625  const JsonVal& getAttr(const std::string& key) const;
626
627  /* \brief get number of nodes in the graph */
628  size_t size() const;
629
630  // internally set passResource to enable tensor allocation for graph passes
631  void _setPassResource(PassResource* res_);
632
633  // internally set arg/aux params when available
634  void _setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
635  std::unordered_map<std::string, mxnet::ext::MXTensor>* aux);
636
637  std::vector<Node*> inputs;
638  std::vector<NodeEntry> outputs;
639  std::map<std::string, JsonVal> attrs;
640
641  private:
642  std::vector<Node*> nodes;
643  PassResource* res;
644 };
645 
646 /* \brief An abstract class for library authors creating custom
647  * partitioners. Optional, can just implement supportedOps instead
648  */
650  public:
651  /* \brief Select a node to include in subgraph, return true to include node
652  * nodeID - index of node in graph
653  */
654  virtual bool Select(int nodeID) = 0;
655  /* \brief Select an input node from current node to include in subgraph
656  * return true to include node
657  * nodeID - index of node in graph
658  * input_nodeID - index of input node in graph
659  */
660  virtual bool SelectInput(int nodeID, int input_nodeID) = 0;
661  /* \brief Select an output node from current node to include in subgraph
662  * return true to include node
663  * nodeID - index of node in graph
664  * output_nodeID - index of output node in graph
665  */
666  virtual bool SelectOutput(int nodeID, int output_nodeID) = 0;
667  /* \brief Review nodes to include in subgraph
668  * return set of candidate nodes to keep in subgraph
669  * candidates - indices of nodes to include in subgraph
670  * keep - indices of nodes to keep in subgraph
671  */
672  virtual void Filter(const std::vector<int>& candidates,
673  std::vector<int>* keep) {
674  keep->insert(keep->end(), candidates.begin(), candidates.end());
675  }
676  /* \brief Reset any selector state, called after growing subgraph, before filter
677  * Called after finished calling SelectInput/SelectOutput and growing subgraph
678  */
679  virtual void Reset() {}
680 };
681 
688  public:
689  virtual MXReturnValue Forward(std::vector<MXTensor>* inputs,
690  std::vector<MXTensor>* outputs,
691  const OpResource& op_res) = 0;
692  virtual MXReturnValue Backward(std::vector<MXTensor>* inputs,
693  std::vector<MXTensor>* outputs,
694  const OpResource& op_res) {
695  MX_ERROR_MSG << "Error! Operator does not support backward" << std::endl;
696  return MX_FAIL;
697  }
698 };
699 
702  public:
703  explicit CustomStatefulOpWrapper(CustomStatefulOp* inst) : instance(inst) {}
704  CustomStatefulOp* get_instance() { return instance; }
705  private:
706  CustomStatefulOp* instance;
707 };
708 
/* Compute callback registered via CustomOp::setForward / setBackward. It
 * receives the op attributes, input and output tensors, and an OpResource. */
710 typedef MXReturnValue (*fcomp_t)(const std::unordered_map<std::string,
711  std::string>& attributes,
712  std::vector<MXTensor>* inputs,
713  std::vector<MXTensor>* outputs,
714  const OpResource& res);
/* Registered via CustomOp::setParseAttrs; reports the number of inputs and
 * outputs for the given attributes. */
715 typedef MXReturnValue (*parseAttrs_t)(const std::unordered_map<std::string,
716  std::string>& attributes,
717  int* num_inputs, int* num_outputs);
/* Registered via CustomOp::setInferType; dtypes are plain ints, presumably
 * MXDType values. */
718 typedef MXReturnValue (*inferType_t)(const std::unordered_map<std::string,
719  std::string>& attributes,
720  std::vector<int>* in_types,
721  std::vector<int>* out_types);
/* Registered via CustomOp::setInferSType; storage types are plain ints,
 * presumably MXStorageType values. */
722 typedef MXReturnValue (*inferSType_t)(const std::unordered_map<std::string,
723  std::string>& attributes,
724  std::vector<int>* in_storage_types,
725  std::vector<int>* out_storage_types);
/* Registered via CustomOp::setInferShape. NOTE(review): dimensions here are
 * unsigned int, while MXTensor::shape uses int64_t. */
726 typedef MXReturnValue (*inferShape_t)(const std::unordered_map<std::string,
727  std::string>& attributes,
728  std::vector<std::vector<unsigned int> >* in_shapes,
729  std::vector<std::vector<unsigned int> >* out_shapes);
/* Registered via CustomOp::setMutateInputs; presumably lists the indices of
 * inputs the op mutates. */
730 typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map<std::string,
731  std::string>& attributes,
732  std::vector<int>* input_indices);
/* Registered via CustomOp::setCreateOpState; produces a CustomStatefulOp
 * through the out-parameter.
 * NOTE(review): in_types is passed by value (the top-level const is ignored in
 * the function type), so the vector is copied on every call. Changing it to a
 * reference would alter the function-pointer type and break existing user
 * callbacks. */
733 typedef MXReturnValue (*createOpState_t)(const std::unordered_map<std::string,
734  std::string>& attributes,
735  const MXContext& ctx,
736  const std::vector<std::vector<unsigned int> >& in_shapes,
737  const std::vector<int> in_types,
738  CustomStatefulOp**);
739 
/* Registration record for one custom operator, populated through the chained
 * set* calls (typically via REGISTER_OP). Definitions are outside this header.
 * NOTE(review): setParseAttrs/setInferType/setInferSType/setInferShape/
 * setMutateInputs store their callbacks in members that are not visible in
 * this listing (lines appear to have been dropped) — confirm. */
743 class CustomOp {
744  public:
  /* op_name is stored as a non-owning pointer; REGISTER_OP passes a string literal. */
745  explicit CustomOp(const char* op_name);
746
747  CustomOp& setForward(fcomp_t fcomp, const char* ctx);
748
749  CustomOp& setBackward(fcomp_t fgrad, const char* ctx);
750
751  CustomOp& setParseAttrs(parseAttrs_t func);
752
753  CustomOp& setInferType(inferType_t func);
754
755  CustomOp& setInferSType(inferSType_t func);
756
757  CustomOp& setInferShape(inferShape_t func);
758
759  CustomOp& setMutateInputs(mutateInputs_t func);
760
761  CustomOp& setCreateOpState(createOpState_t func, const char* ctx);
762
763  CustomOp& setIsSubgraphOp();
764
  /* Presumably copies the per-context maps into the parallel *_cstr / *_fp
   * vectors below, which are exposed to MXNet — confirm in the .cc. */
765  void mapToVector();
766
768  const char* name;
769
776  bool isSGop;
777
  /* Parallel arrays: entry i of a *_ctx_cstr vector is the context string for
   * entry i of the matching *_fp vector (presumably filled by mapToVector). */
779  std::vector<const char*> forward_ctx_cstr, backward_ctx_cstr, create_op_ctx_cstr;
780  std::vector<fcomp_t> forward_fp, backward_fp;
781  std::vector<createOpState_t> create_op_fp;
782
783  private:
784  void raiseDuplicateContextError();
785
  /* NOTE(review): std::hash<const char*> and operator== on const char* use the
   * pointer value, not the string contents. Two distinct char arrays holding
   * the same context text are therefore different keys, and duplicate-context
   * detection depends on string-literal pooling. */
787  std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
788  std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
789 };
790 
793  const std::unordered_map<std::string, std::string>& options);
794 
798 class CustomPass {
799  public:
800  CustomPass();
801  explicit CustomPass(const char* pass_name);
802 
803  CustomPass& setBody(graphPass_t fn);
804 
806  const char* name;
809 };
810 
812 typedef MXReturnValue (*supportedOps_t)(const mxnet::ext::Graph *graph, std::vector<int>* ids,
813  const std::unordered_map<std::string,
814  std::string>& options);
816  CustomOpSelector** sel_inst,
817  const std::unordered_map<std::string,
818  std::string>& options);
819 typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id,
820  bool* accept,
821  const std::unordered_map<std::string,
822  std::string>& options,
823  std::unordered_map<std::string,
824  std::string>* attrs);
825 
830  public:
832 
833  explicit CustomPartitioner(const char* backend_name);
834 
835  CustomPartitioner& addStrategy(const char* prop_name,
836  const char* sg_name);
837 
838  CustomPartitioner& setSupportedOps(const char* prop_name, supportedOps_t fn);
839 
840  CustomPartitioner& setCreateSelector(const char* prop_name, createSelector_t fn);
841 
842  CustomPartitioner& setReviewSubgraph(const char* prop_name, reviewSubgraph_t fn);
843 
844  supportedOps_t getSupportedOps(int stg_id);
845 
846  createSelector_t getCreateSelector(int stg_id);
847 
848  reviewSubgraph_t getReviewSubgraph(int stg_id);
849 
851  const char* name;
852  std::map<std::string, supportedOps_t> supported_map;
853  std::map<std::string, createSelector_t> selector_map;
854  std::map<std::string, reviewSubgraph_t> review_map;
856  std::vector<const char*> strategies;
858  std::vector<const char*> op_names;
859 };
860 
865 template <class T>
866 class Registry {
867  public:
872  static Registry* get() PRIVATE_SYMBOL {
873  static Registry inst;
874  return &inst;
875  }
880  T& add(const char* name) {
881  T *entry = new T(name);
882  entries.push_back(entry);
883  return *entry;
884  }
885  int size() {
886  return entries.size();
887  }
888  T& get(int idx) {
889  return *(entries.at(idx));
890  }
891 
892  private:
894  Registry() {}
896  ~Registry() {}
898  std::vector<T*> entries;
899 };
900 
/* Two-level helpers so that macro arguments (e.g. __COUNTER__) are expanded
 * before being token-pasted or stringized.
 * NOTE(review): identifiers beginning with a double underscore (__a, __b) are
 * reserved for the implementation in C++; renaming them would be safer. */
906 #define MX_STR_CONCAT_(__a, __b) __a ## __b
907 #define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b)
908
910 #define MX_STRINGIFY(x) #x
911 #define MX_TOSTRING(x) MX_STRINGIFY(x)
912
/* Declarators for the global objects created by the REGISTER_* macros. The
 * Name argument is not used. The variable name becomes unique because
 * REGISTER_* pastes a __COUNTER__ value onto the trailing underscore. */
914 #define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _
915 #define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name)
916
917 #define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _
918 #define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name)
919
920 #define MX_REGISTER_PASS_NAME_(Name) MXNet ## _CustomPass ## _
921 #define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name)
922
/* REGISTER_OP(Name) / REGISTER_PARTITIONER(Name) / REGISTER_PASS(Name) each
 * create a registry entry named after the stringized Name. They then
 * copy-initialize a uniquely named global (e.g. MXNet_CustomOp_0) from that
 * entry. Setters chained after the macro presumably act on the registry entry
 * itself, before the copy is made.
 * NOTE(review): __COUNTER__ is a compiler extension (provided by GCC, Clang and
 * MSVC), not standard C++. */
924 #define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \
925  mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->add(MX_TOSTRING(Name))
926
927 #define REGISTER_PARTITIONER(Name) \
928  MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \
929  mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->add(MX_TOSTRING(Name))
930
931 #define REGISTER_PASS(Name) \
932  MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \
933  mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->add(MX_TOSTRING(Name))
934 
935 /* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */
936 
942 #define MXLIB_OPREGSIZE_STR "_opRegSize"
943 typedef int (*opRegSize_t)(void);
944 
945 #define MXLIB_OPREGGET_STR "_opRegGet"
946 typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop,
947  const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp,
948  int* forward_count, const char*** backward_ctx,
949  mxnet::ext::fcomp_t** backward_fp, int* backward_count,
950  const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp,
951  int* create_op_count, mxnet::ext::parseAttrs_t* parse,
954 
955 #define MXLIB_OPCALLFREE_STR "_opCallFree"
956 typedef int (*opCallFree_t)(void* ptr);
957 
958 #define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs"
959 typedef int (*opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char* const* keys,
960  const char* const* vals, int num,
961  int* num_in, int* num_out);
962 
963 #define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape"
964 typedef int (*opCallInferShape_t)(inferShape_t inferShape, const char* const* keys,
965  const char* const* vals, int num,
966  unsigned int** inshapes, int* indims, int num_in,
967  unsigned int*** mod_inshapes, int** mod_indims,
968  unsigned int*** outshapes, int** outdims, int num_out);
969 
970 #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType"
971 typedef int (*opCallInferType_t)(inferType_t inferType, const char* const* keys,
972  const char* const* vals, int num,
973  int* intypes, int num_in, int* outtypes, int num_out);
974 
975 #define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
976 typedef int (*opCallInferSType_t)(inferSType_t inferSType, const char* const* keys,
977  const char* const* vals, int num,
978  int* intypes, int num_in, int* outtypes, int num_out);
979 
980 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
981 typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
982  const char* const* vals, int num,
983  const int64_t** inshapes, int* indims,
984  void** indata, int* intypes,
985  size_t* inIDs, const char** indev_type,
986  int* indev_id, int num_in,
987  const int64_t** outshapes, int* outdims,
988  void** outdata, int* outtypes,
989  size_t* outIDs, const char** outdev_type,
990  int* outdev_id, int num_out,
991  xpu_malloc_t cpu_malloc, void* cpu_alloc,
992  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
993  sparse_malloc_t sparse_malloc, void* sparse_alloc,
994  int* instypes, int* outstypes,
995  void** in_indices, void** out_indices,
996  void** in_indptr, void** out_indptr,
997  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
998  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
999  void* rng_cpu_states, void* rng_gpu_states);
1000 
1001 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
1002 typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
1003  const char* const* vals, int num,
1004  int** mutate_indices, int* indices_size);
1005 
1006 #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
1007 typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys,
1008  const char* const* vals, int num, const char* dev_type,
1009  int dev_id, unsigned int** inshapes, int* indims,
1010  int num_in, const int* intypes, void** state_op);
1011 
1012 #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
1013 typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
1014  const int64_t** inshapes, int* indims,
1015  void** indata, int* intypes,
1016  size_t* inIDs, const char** indev_type,
1017  int* indev_id, int num_in,
1018  const int64_t** outshapes, int* outdims,
1019  void** outdata, int* outtypes,
1020  size_t* outIDs, const char** outdev_type,
1021  int* outdev_id, int num_out,
1022  xpu_malloc_t cpu_malloc, void* cpu_alloc,
1023  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
1024  sparse_malloc_t sparse_malloc, void* sparse_alloc,
1025  int* instypes, int* outstypes,
1026  void** in_indices, void** out_indices,
1027  void** in_indptr, void** out_indptr,
1028  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1029  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1030  void* rng_cpu_states, void* rng_gpu_states);
1031 
1032 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
1033 typedef int (*partRegSize_t)(void);
1034 
1035 #define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount"
1036 typedef int (*partRegGetCount_t)(int idx, const char** name);
1037 
1038 #define MXLIB_PARTREGGET_STR "_partRegGet"
1039 typedef void (*partRegGet_t)(int part_idx, int stg_idx, const char** strategy,
1040  supportedOps_t* supportedOps, createSelector_t* createSelector,
1041  reviewSubgraph_t* reviewSubgraph, const char** op_name);
1042 
1043 #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps"
1044 typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json,
1045  int num_ids, int *ids, const char* const* opt_keys,
1046  const char* const* opt_vals, int num_opts);
1047 
1048 #define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector"
1049 typedef int (*partCallCreateSelector_t)(createSelector_t createSelector, const char *json,
1050  void** selector, const char* const* opt_keys,
1051  const char* const* opt_vals, int num_opts);
1052 
1053 #define MXLIB_PARTCALLSELECT_STR "_partCallSelect"
1054 typedef void (*partCallSelect_t)(void* sel_inst, int nodeID, int* selected);
1055 
1056 #define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput"
1057 typedef void (*partCallSelectInput_t)(void* sel_inst, int nodeID, int input_nodeID,
1058  int* selected);
1059 
1060 #define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput"
1061 typedef void (*partCallSelectOutput_t)(void* sel_inst, int nodeID, int output_nodeID,
1062  int* selected);
1063 
1064 #define MXLIB_PARTCALLFILTER_STR "_partCallFilter"
1065 typedef void (*partCallFilter_t)(void* sel_inst, int* candidates, int num_candidates,
1066  int** keep, int* num_keep);
1067 
1068 #define MXLIB_PARTCALLRESET_STR "_partCallReset"
1069 typedef void (*partCallReset_t)(void* sel_inst);
1070 
1071 #define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph"
1072 typedef int (*partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json,
1073  int subgraph_id, int *accept, const char* const* opt_keys,
1074  const char* const* opt_vals, int num_opts,
1075  char*** attr_keys, char*** attr_vals, int *num_attrs,
1076  const char* const* arg_names, int num_args,
1077  void* const* arg_data, const int64_t* const* arg_shapes,
1078  const int* arg_dims, const int* arg_types,
1079  const size_t* arg_IDs, const char* const* arg_dev_type,
1080  const int* arg_dev_id,
1081  const char* const* aux_names, int num_aux,
1082  void* const* aux_data, const int64_t* const* aux_shapes,
1083  const int* aux_dims, const int* aux_types,
1084  const size_t* aux_IDs, const char* const* aux_dev_type,
1085  const int* aux_dev_id);
1086 
1087 #define MXLIB_PASSREGSIZE_STR "_passRegSize"
1088 typedef int (*passRegSize_t)(void);
1089 
1090 #define MXLIB_PASSREGGET_STR "_passRegGet"
1091 typedef void (*passRegGet_t)(int pass_idx, graphPass_t* graphPass, const char** pass_name);
1092 
1093 #define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass"
1094 typedef int (*passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph,
1095  char** out_graph, const char* const* opt_keys,
1096  const char* const* opt_vals, int num_opts,
1097  const char* pass_name, const char* const* arg_names,
1098  int num_args, void* const* arg_data,
1099  const int64_t* const* arg_shapes, const int* arg_dims,
1100  const int* arg_types, const size_t* arg_IDs,
1101  const char* const* arg_dev_type, const int* arg_dev_id,
1102  const char* const* aux_names, int num_aux,
1103  void* const* aux_data, const int64_t* const* aux_shapes,
1104  const int* aux_dims, const int* aux_types,
1105  const size_t* aux_IDs, const char* const* aux_dev_type,
1106  const int* aux_dev_id, nd_malloc_t nd_malloc,
1107  const void* nd_alloc);
1108 
1109 #define MXLIB_INITIALIZE_STR "initialize"
1110 typedef int (*initialize_t)(int version);
1111 
1112 #define MXLIB_OPVERSION_STR "_opVersion"
1113 typedef int (*opVersion_t)();
1114 
1115 #define MXLIB_MSGSIZE_STR "_msgSize"
1116 typedef int (*msgSize_t)(void);
1117 
1118 #define MXLIB_MSGGET_STR "_msgGet"
1119 typedef int (*msgGet_t)(int idx, const char** msg);
1120 
1121 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
1122 #define MX_INT_RET __declspec(dllexport) int __cdecl
1123 #define MX_VOID_RET __declspec(dllexport) void __cdecl
1124 #else
1125 #define MX_INT_RET int
1126 #define MX_VOID_RET void
1127 #endif
1128 
1129 } // namespace ext
1130 } // namespace mxnet
1131 
1132 extern "C" {
1135 
1138 
1140  MX_VOID_RET _opRegGet(int idx, const char** name, int *isSGop,
1141  const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp,
1142  int* forward_count, const char*** backward_ctx,
1143  mxnet::ext::fcomp_t** backward_fp, int* backward_count,
1144  const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp,
1145  int* create_op_count, mxnet::ext::parseAttrs_t* parse,
1148 
1150  MX_VOID_RET _opCallFree(void* ptr);
1151 
1153  MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char* const* keys,
1154  const char* const* vals, int num,
1155  int* num_in, int* num_out);
1156 
1158  MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char* const* keys,
1159  const char* const* vals, int num,
1160  unsigned int** inshapes, int* indims, int num_in,
1161  unsigned int*** mod_inshapes, int** mod_indims,
1162  unsigned int*** outshapes, int** outdims, int num_out);
1163 
1165  MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char* const* keys,
1166  const char* const* vals, int num,
1167  int* intypes, int num_in, int* outtypes, int num_out);
1168 
1170  MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char* const* keys,
1171  const char* const* vals, int num,
1172  int* instypes, int num_in, int* outstypes, int num_out);
1173 
1175  MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys,
1176  const char* const* vals,
1177  int num, const int64_t** inshapes, int* indims, void** indata,
1178  int* intypes, size_t* inIDs, const char** indev_type, int* indev_id,
1179  int num_in, const int64_t** outshapes, int* outdims, void** outdata,
1180  int* outtypes, size_t* outIDs, const char** outdev_type,
1181  int* outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc,
1182  void* cpu_alloc,
1183  mxnet::ext::xpu_malloc_t gpu_malloc, void* gpu_alloc,
1184  void* cuda_stream,
1185  mxnet::ext::sparse_malloc_t sparse_malloc, void* sparse_alloc,
1186  int* instypes, int* outstypes, void** in_indices, void** out_indices,
1187  void** in_indptr, void** out_indptr,
1188  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1189  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1190  void* rng_cpu_states, void* rng_gpu_states);
1191 
1193  MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* const* keys,
1194  const char* const* vals, int num,
1195  int** mutate_indices, int* indices_size);
1196 
1198  MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
1199  const char* const* vals, int num, const char* dev_type,
1200  int dev_id, unsigned int** inshapes, int* indims,
1201  int num_in, const int* intypes, void** state_op);
1202 
1204  MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
1205  int* indims, void** indata, int* intypes, size_t* inIDs,
1206  const char** indev_type, int* indev_id, int num_in,
1207  const int64_t** outshapes, int* outdims, void** outdata,
1208  int* outtypes, size_t* outIDs, const char** outdev_type,
1209  int* outdev_id, int num_out,
1210  mxnet::ext::xpu_malloc_t cpu_malloc,
1211  void* cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc,
1212  void* gpu_alloc,
1213  void* stream, mxnet::ext::sparse_malloc_t sparse_malloc,
1214  void* sparse_alloc, int* instypes, int* outstypes,
1215  void** in_indices, void** out_indices, void** in_indptr,
1216  void** out_indptr, int64_t* in_indices_shapes,
1217  int64_t* out_indices_shapes, int64_t* in_indptr_shapes,
1218  int64_t* out_indptr_shapes,
1219  void* rng_cpu_states, void* rng_gpu_states);
1220 
1223 
1224  /* returns number of strategies registered for partitioner
1225  * at specified index */
1226  MX_INT_RET _partRegGetCount(int idx, const char** name);
1227 
1229  MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char** strategy,
1230  mxnet::ext::supportedOps_t* supportedOps,
1231  mxnet::ext::createSelector_t* createSelector,
1232  mxnet::ext::reviewSubgraph_t* reviewSubgraph, const char** op_name);
1233 
1235  MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json,
1236  int num_ids, int *ids, const char* const* opt_keys,
1237  const char* const* opt_vals, int num_opts);
1238 
1240  MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json,
1241  void** selector, const char* const* opt_keys,
1242  const char* const* opt_vals, int num_opts);
1243 
1245  MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected);
1246 
1248  MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID,
1249  int input_nodeID, int* selected);
1250 
1252  MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID,
1253  int output_nodeID, int* selected);
1254 
1256  MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates,
1257  int** keep, int* num_keep);
1258 
1260  MX_VOID_RET _partCallReset(void* sel_inst);
1261 
1263  MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json,
1264  int subgraph_id, int *accept, const char* const* opt_keys,
1265  const char* const* opt_vals, int num_opts,
1266  char*** attr_keys, char*** attr_vals, int *num_attrs,
1267  const char* const* arg_names, int num_args,
1268  void* const* arg_data, const int64_t* const* arg_shapes,
1269  const int* arg_dims, const int* arg_types,
1270  const size_t* arg_IDs, const char* const* arg_dev_type,
1271  const int* arg_dev_id,
1272  const char* const* aux_names, int num_aux,
1273  void* const* aux_data, const int64_t* const* aux_shapes,
1274  const int* aux_dims, const int* aux_types,
1275  const size_t* aux_IDs, const char* const* aux_dev_type,
1276  const int* aux_dev_id);
1277 
1280 
1282  MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass,
1283  const char** pass_name);
1284 
1286  MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json,
1287  char** out_graph, const char* const* opt_keys,
1288  const char* const* opt_vals, int num_opts,
1289  const char* pass_name, const char* const* arg_names, int num_args,
1290  void* const* arg_data, const int64_t* const* arg_shapes,
1291  const int* arg_dims, const int* arg_types,
1292  const size_t* arg_IDs, const char* const* arg_dev_type,
1293  const int* arg_dev_id, const char* const* aux_names, int num_aux,
1294  void* const* aux_data, const int64_t* const* aux_shapes,
1295  const int* aux_dims, const int* aux_types,
1296  const size_t* aux_IDs, const char* const* aux_dev_type,
1297  const int* aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc,
1298  const void* nd_alloc);
1299 
1307 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
1308  __declspec(dllexport) mxnet::ext::MXReturnValue __cdecl
1309 #else
1311 #endif
1312  initialize(int version);
1313 
1314  MX_INT_RET _msgSize();
1315 
1317  MX_VOID_RET _msgGet(int idx, const char** msg);
1318 } // extern "C"
1319 
1320 #endif // MXNET_LIB_API_H_
An abstract class for subgraph property.
Definition: lib_api.h:829
int(* opVersion_t)()
Definition: lib_api.h:1113
int(* partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1044
const char * name
partitioner name
Definition: lib_api.h:851
MXReturnValue(* inferShape_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< std::vector< unsigned int > > *in_shapes, std::vector< std::vector< unsigned int > > *out_shapes)
Definition: lib_api.h:726
DLDeviceType
The device type in DLContext.
Definition: lib_api.h:102
std::string getShapeAt(const std::string &shape, unsigned index)
int(* opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
Definition: lib_api.h:959
MX_VOID_RET _opRegGet(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
returns operator registration at specified index
virtual MXReturnValue Backward(std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &op_res)
Definition: lib_api.h:692
int(* msgGet_t)(int idx, const char **msg)
Definition: lib_api.h:1119
const char * name
operator name
Definition: lib_api.h:768
Definition: lib_api.h:258
MXTensor * tensor
Definition: lib_api.h:572
virtual void Filter(const std::vector< int > &candidates, std::vector< int > *keep)
Definition: lib_api.h:672
provides resource APIs (memory allocation mechanism) to Forward/Backward functions
Definition: lib_api.h:418
Definition: lib_api.h:256
void(* partCallReset_t)(void *sel_inst)
Definition: lib_api.h:1069
std::vector< const char * > strategies
strategy names
Definition: lib_api.h:856
OpenCL devices.
Definition: lib_api.h:113
std::vector< NodeEntry > outputs
Definition: lib_api.h:574
DLDataTypeCode
The type code options DLDataType.
Definition: lib_api.h:143
MX_VOID_RET _partCallSelectOutput(void *sel_inst, int nodeID, int output_nodeID, int *selected)
calls the select output function from library; the result is written to selected
StatefulOp wrapper class to pass to backend OpState.
Definition: lib_api.h:701
MX_INT_RET _opRegSize()
returns number of ops registered in this library
T & add(const char *name)
add a new entry
Definition: lib_api.h:880
Metal for Apple GPU.
Definition: lib_api.h:117
Definition: lib_api.h:297
int(* passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, nd_malloc_t nd_malloc, const void *nd_alloc)
Definition: lib_api.h:1094
namespace of mxnet
Definition: api_registry.h:33
Definition: lib_api.h:146
void(* partCallSelectInput_t)(void *sel_inst, int nodeID, int input_nodeID, int *selected)
Definition: lib_api.h:1057
#define MX_VOID_RET
Definition: lib_api.h:1126
ROCm GPUs for AMD GPUs.
Definition: lib_api.h:121
A Device context for Tensor and operator.
Definition: dlpack.h:69
std::vector< Graph * > subgraphs
Definition: lib_api.h:575
CUDA GPU device.
Definition: lib_api.h:106
Node * node
Definition: lib_api.h:550
void(* partRegGet_t)(int part_idx, int stg_idx, const char **strategy, supportedOps_t *supportedOps, createSelector_t *createSelector, reviewSubgraph_t *reviewSubgraph, const char **op_name)
Definition: lib_api.h:1039
MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
returns status of calling mutateInputs function for operator from library
int64_t * indices
Definition: lib_api.h:306
int64_t data_len
Definition: lib_api.h:301
MXStorageType
Definition: lib_api.h:264
Reserved extension device type, used for quickly testing an extension device. The semantics can differ depending on the implementation.
Definition: lib_api.h:127
Pinned CUDA GPU device by cudaMallocHost.
Definition: lib_api.h:111
inferSType_t infer_storage_type
Definition: lib_api.h:773
MXContext ctx
Definition: lib_api.h:360
#define PRIVATE_SYMBOL
For loading multiple custom op libraries in Linux, exporting same symbol multiple times may lead to u...
Definition: lib_api.h:66
MX_INT_RET _partRegSize()
returns number of partitioners registered in this library
Definition: lib_api.h:496
MXReturnValue(* graphPass_t)(mxnet::ext::Graph *graph, const std::unordered_map< std::string, std::string > &options)
Custom Pass Create function template.
Definition: lib_api.h:792
MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling create selector function from library
MX_VOID_RET _partCallSelect(void *sel_inst, int nodeID, int *selected)
calls the select function from library; the result is written to selected
Definition: lib_api.h:255
int num
Definition: lib_api.h:536
Tensor data structure used by custom operator.
Definition: lib_api.h:321
MX_VOID_RET _partCallFilter(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
calls the filter function from library; the kept candidates are written to keep and num_keep
std::unordered_map< std::string, std::string > attrs
Definition: lib_api.h:576
void *(* xpu_malloc_t)(void *, int)
resource malloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:371
int(* opCallInferShape_t)(inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
Definition: lib_api.h:964
MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *instypes, int num_in, int *outstypes, int num_out)
returns status of calling inferSType function for operator from library
MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling supported ops function from library
MX_VOID_RET _partCallSelectInput(void *sel_inst, int nodeID, int input_nodeID, int *selected)
calls the select input function from library; the result is written to selected
MX_INT_RET _msgSize()
returns the number of library messages that can be fetched with _msgGet
MXReturnValue(* createOpState_t)(const std::unordered_map< std::string, std::string > &attributes, const MXContext &ctx, const std::vector< std::vector< unsigned int > > &in_shapes, const std::vector< int > in_types, CustomStatefulOp **)
Definition: lib_api.h:733
MXReturnValue
Definition: lib_api.h:291
void(* partCallSelectOutput_t)(void *sel_inst, int nodeID, int output_nodeID, int *selected)
Definition: lib_api.h:1061
std::string name
Definition: lib_api.h:571
MX_INT_RET _opCallFStatefulCompute(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Stateful Forward/Backward for operator from library
CustomStatefulOp * get_instance()
Definition: lib_api.h:704
std::map< std::string, createSelector_t > selector_map
Definition: lib_api.h:853
std::vector< NodeEntry > inputs
Definition: lib_api.h:573
Definition: lib_api.h:254
CustomStatefulOpWrapper(CustomStatefulOp *inst)
Definition: lib_api.h:703
graphPass_t pass
pass function
Definition: lib_api.h:808
Definition: lib_api.h:257
std::string str
Definition: lib_api.h:537
MXReturnValue(* createSelector_t)(const mxnet::ext::Graph *graph, CustomOpSelector **sel_inst, const std::unordered_map< std::string, std::string > &options)
Definition: lib_api.h:815
MXDType dtype
Definition: lib_api.h:354
int(* opCallMutateInputs_t)(mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
Definition: lib_api.h:1002
Definition: lib_api.h:144
inferType_t infer_type
Definition: lib_api.h:772
std::map< std::string, JsonVal > attrs
Definition: lib_api.h:639
void(* partCallSelect_t)(void *sel_inst, int nodeID, int *selected)
Definition: lib_api.h:1054
int entry
Definition: lib_api.h:551
Definition: lib_api.h:251
Definition: lib_api.h:221
Class to hold custom operator registration.
Definition: lib_api.h:743
MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
returns status of calling inferType function for operator from library
Definition: lib_api.h:583
int(* opRegGet_t)(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
Definition: lib_api.h:946
Vulkan buffer for next generation graphics.
Definition: lib_api.h:115
MX_INT_RET _opVersion()
returns MXNet library version
inferShape_t infer_shape
Definition: lib_api.h:774
#define MX_ERROR_MSG
Definition: lib_api.h:245
int64_t indices_len
Definition: lib_api.h:307
MXReturnValue(* inferSType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_storage_types, std::vector< int > *out_storage_types)
Definition: lib_api.h:722
std::vector< const char * > op_names
subgraph operator name
Definition: lib_api.h:858
MXReturnValue(* reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, bool *accept, const std::unordered_map< std::string, std::string > &options, std::unordered_map< std::string, std::string > *attrs)
Definition: lib_api.h:819
DLDeviceType
The device type in DLContext.
Definition: dlpack.h:38
MX_INT_RET _partRegGetCount(int idx, const char **name)
void * mx_gpu_rand_t
Definition: lib_api.h:384
Verilog simulator buffer.
Definition: lib_api.h:119
mx_stream_t get_cuda_stream() const
return the cuda stream object with correct type
Definition: lib_api.h:432
std::vector< createOpState_t > create_op_fp
Definition: lib_api.h:781
MXReturnValue(* fcomp_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &res)
Custom Operator function templates.
Definition: lib_api.h:710
mxnet::ext::MXReturnValue initialize(int version)
Checks if the MXNet version is supported by the library. If supported, initializes the library...
JsonType
Json utility to parse serialized subgraph symbol.
Definition: lib_api.h:496
MX_VOID_RET _opCallFree(void *ptr)
calls free from the external library for library allocated arrays
MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
returns status of calling review subgraph function from library
int(* opCallInferSType_t)(inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:976
MX_INT_RET _passRegSize()
returns number of graph passes registered in this library
Definition: lib_api.h:145
int(* passRegSize_t)(void)
Definition: lib_api.h:1088
definition of JSON objects
Definition: lib_api.h:499
Definition: lib_api.h:549
MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, const void *nd_alloc)
returns status of calling graph pass function from library
int(* partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
Definition: lib_api.h:1072
std::string getDtypeAt(const std::string &dtype, unsigned index)
bool isSGop
Definition: lib_api.h:776
MXReturnValue(* inferType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_types, std::vector< int > *out_types)
Definition: lib_api.h:718
int(* opCallFree_t)(void *ptr)
Definition: lib_api.h:956
int(* opCallFStatefulComp_t)(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:1013
size_t verID
Definition: lib_api.h:357
std::vector< JsonVal > list
Definition: lib_api.h:538
MXReturnValue(* supportedOps_t)(const mxnet::ext::Graph *graph, std::vector< int > *ids, const std::unordered_map< std::string, std::string > &options)
Custom Subgraph Create function template.
Definition: lib_api.h:812
#define MX_INT_RET
Definition: lib_api.h:1125
const char * name
pass name
Definition: lib_api.h:806
parseAttrs_t parse_attrs
operator functions
Definition: lib_api.h:771
Definition: lib_api.h:394
data_type * data()
helper function to cast data pointer
Definition: lib_api.h:336
MXReturnValue(* parseAttrs_t)(const std::unordered_map< std::string, std::string > &attributes, int *num_inputs, int *num_outputs)
Definition: lib_api.h:715
mutateInputs_t mutate_inputs
Definition: lib_api.h:775
Context info passed from MXNet OpContext. dev_type is the string representation of the supported context, currently only "cpu" and "gpu"; dev_id is the device index where the tensor is located.
Definition: lib_api.h:278
std::mt19937 mx_cpu_rand_t
Definition: lib_api.h:386
MX_VOID_RET _msgGet(int idx, const char **msg)
returns, through msg, the library message at the specified index
MX_VOID_RET _partCallReset(void *sel_inst)
calls the reset selector function from library
Definition: lib_api.h:292
void * mx_stream_t
GPU stream pointer, is void* when not compiled with CUDA.
Definition: lib_api.h:383
An abstract class for library authors creating stateful op custom library should override Forward and...
Definition: lib_api.h:687
int(* partCallCreateSelector_t)(createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1049
std::map< std::string, supportedOps_t > supported_map
Definition: lib_api.h:852
int size()
Definition: lib_api.h:885
Definition: lib_api.h:293
std::vector< const char * > forward_ctx_cstr
vector repr of ctx map to be easily loaded from c_api
Definition: lib_api.h:779
CPU device.
Definition: lib_api.h:104
int(* partRegSize_t)(void)
Definition: lib_api.h:1033
Definition: lib_api.h:555
void(* passRegGet_t)(int pass_idx, graphPass_t *graphPass, const char **pass_name)
Definition: lib_api.h:1091
Definition: lib_api.h:496
void(* sparse_malloc_t)(void *, int, int, int, void **, int64_t **, int64_t **)
sparse alloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:373
MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t *graphPass, const char **pass_name)
returns pass registration at specified index
void(* partCallFilter_t)(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
Definition: lib_api.h:1065
std::vector< NodeEntry > outputs
Definition: lib_api.h:638
int64_t indptr_len
Definition: lib_api.h:312
int dev_id
Definition: lib_api.h:288
std::vector< Node * > inputs
Definition: lib_api.h:637
DLTensor dltensor
Definition: lib_api.h:364
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
returns status of calling createStatefulOp function for operator from library
int(* opCallInferType_t)(inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:971
std::string dev_type
Definition: lib_api.h:287
int(* opCallCreateOpState_t)(createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
Definition: lib_api.h:1007
MXDType
Tensor data type, consistent with mshadow data type.
Definition: lib_api.h:250
int(* partRegGetCount_t)(int idx, const char **name)
Definition: lib_api.h:1036
std::vector< fcomp_t > forward_fp
Definition: lib_api.h:780
Definition: lib_api.h:496
MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
returns status of calling inferShape function for operator from library
int(* initialize_t)(int version)
Definition: lib_api.h:1110
std::map< JsonVal, JsonVal > map
Definition: lib_api.h:539
Definition: lib_api.h:270
int(* opRegSize_t)(void)
Definition: lib_api.h:943
An abstract class for graph passes.
Definition: lib_api.h:798
std::vector< int64_t > shape
Definition: lib_api.h:351
Registry class to register things (ops, properties). Singleton class.
Definition: lib_api.h:866
JsonType type
Definition: lib_api.h:535
Definition: lib_api.h:268
The data type the tensor can hold.
Definition: dlpack.h:94
void(* nd_malloc_t)(const void *_ndarray_alloc, const int64_t *shapes, int num_shapes, const char *dev_str, int dev_id, int dtype, const char *name, int isArg, void **data)
resource malloc function to allocate ndarrays for graph passes
Definition: lib_api.h:375
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Forward/Backward function for operator from library
Definition: lib_api.h:649
std::stringstream & add(const char *file, int line)
virtual void Reset()
Definition: lib_api.h:679
Definition: lib_api.h:266
std::string op
Definition: lib_api.h:570
Definition: lib_api.h:252
Definition: lib_api.h:253
Definition: lib_api.h:496
Definition: lib_api.h:496
MXReturnValue(* mutateInputs_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *input_indices)
Definition: lib_api.h:730
int(* msgSize_t)(void)
Definition: lib_api.h:1116
mx_gpu_rand_t * get_gpu_rand_states() const
get pointer to initialized and seeded random number states located on GPU
Definition: lib_api.h:446
std::map< std::string, reviewSubgraph_t > review_map
Definition: lib_api.h:854
MXStorageType stype
Definition: lib_api.h:367
int(* opCallFComp_t)(fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:981
MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char **strategy, mxnet::ext::supportedOps_t *supportedOps, mxnet::ext::createSelector_t *createSelector, mxnet::ext::reviewSubgraph_t *reviewSubgraph, const char **op_name)
returns partitioner registration at specified index
MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
returns status of calling parse attributes function for operator from library
void * data_ptr
Definition: lib_api.h:348