mxnet
lib_api.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
31 #ifndef MXNET_LIB_API_H_
32 #define MXNET_LIB_API_H_
33 
34 #include <stdint.h>
35 #include <stdlib.h>
36 #include <string.h>
37 #include <vector>
38 #include <map>
39 #include <unordered_set>
40 #include <unordered_map>
41 #include <string>
42 #include <iostream>
43 #include <utility>
44 #include <stdexcept>
45 #include <functional>
46 #include <random>
47 #include <sstream>
48 
49 #if defined(__NVCC__)
50  #include <cuda_runtime.h>
51  #include <curand_kernel.h>
52 #endif
53 
54 /* Make sure to update the version number every time you make changes. */
// NOTE(review): presumably the value exchanged through the initialize_t / opVersion_t
// entry points declared later in this header — confirm against the implementation.
55 #define MX_LIBRARY_VERSION 11
56 
// PRIVATE_SYMBOL: attribute appended to declarations that must stay out of the
// shared library's exported symbol table. It expands to nothing on Windows targets
// and to the GCC/Clang hidden-visibility attribute everywhere else.
62 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
63  #define PRIVATE_SYMBOL
64 #else
65  #define PRIVATE_SYMBOL __attribute__ ((visibility ("hidden")))
66 #endif
67 
68 /*
69  * Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
70  */
71 #ifndef DLPACK_VERSION
// DLPACK_EXTERN_C: `extern "C"` when compiled as C++, empty when compiled as C.
72 #ifdef __cplusplus
73 #define DLPACK_EXTERN_C extern "C"
74 #else
75 #define DLPACK_EXTERN_C
76 #endif
77 
// NOTE(review): the leading zero makes `020` an octal literal, i.e. decimal 16.
// The value is kept exactly as imported from DLPack; do not "fix" it to 20 without
// checking every comparison against DLPACK_VERSION.
79 #define DLPACK_VERSION 020
80 
// DLPACK_DLL: dllexport when building with DLPACK_EXPORTS on Windows, dllimport
// otherwise on Windows, and empty on every other platform.
82 #ifdef _WIN32
83 #ifdef DLPACK_EXPORTS
84 #define DLPACK_DLL __declspec(dllexport)
85 #else
86 #define DLPACK_DLL __declspec(dllimport)
87 #endif
88 #else
89 #define DLPACK_DLL
90 #endif
91 
92 #include <stdint.h>
93 #include <stddef.h>
94 
95 #ifdef __cplusplus
96 extern "C" {
97  #endif
98 
 // Device type codes imported from DLPack. The numeric values are fixed and sparse
 // (1, 2, 8, 9, 10, 12 in this listing); they must not be renumbered.
 // NOTE(review): the listing-number gaps suggest some enumerators were dropped when
 // this page was extracted — confirm against the upstream dlpack.h.
101  typedef enum {
103  kDLCPU = 1,
105  kDLGPU = 2,
116  kDLMetal = 8,
118  kDLVPI = 9,
120  kDLROCM = 10,
126  kDLExtDev = 12,
127  } DLDeviceType;
128 
 // Device descriptor imported from DLPack: a device type paired with a device id.
132  typedef struct {
 // kind of device (one of the DLDeviceType codes)
134  DLDeviceType device_type;
 // integer id of the device; presumably an index among devices of that type — confirm
136  int device_id;
137  } DLContext;
138 
 // Type-class codes imported from DLPack: signed integer, unsigned integer, float.
 // DLDataType::code stores one of these values in a uint8_t.
142  typedef enum {
143  kDLInt = 0U,
144  kDLUInt = 1U,
145  kDLFloat = 2U,
146  } DLDataTypeCode;
147 
 // Element type descriptor imported from DLPack.
156  typedef struct {
 // type class; presumably one of the DLDataTypeCode values — confirm
162  uint8_t code;
 // presumably the width of one element in bits (e.g. 32 for float32) — confirm
166  uint8_t bits;
 // presumably the vector lane count (1 for scalar types) — confirm
168  uint16_t lanes;
169  } DLDataType;
170 
 // Plain-C tensor descriptor imported from DLPack. It holds raw pointers only;
 // nothing in this struct expresses ownership of the memory they refer to.
174  typedef struct {
 // pointer to the tensor's memory
194  void* data;
 // device the tensor lives on
196  DLContext ctx;
 // number of dimensions
198  int ndim;
 // element type
200  DLDataType dtype;
 // presumably an array of `ndim` extents — confirm
202  int64_t* shape;
 // presumably an array of `ndim` strides, possibly null for compact layout — confirm
207  int64_t* strides;
 // presumably a byte offset from `data` to the first element — confirm
209  uint64_t byte_offset;
210  } DLTensor;
211 #ifdef __cplusplus
212 } // DLPACK_EXTERN_C
213 #endif
214 #endif
215 
216 namespace mxnet {
217 namespace ext {
218 
219 /* \brief Class to store error messages from extensions to pass to MXNet.
 *
 * Accessed through the static get(); the constructor and destructor are private,
 * so instances cannot be created or destroyed from outside the class.
 * Normally used through the MX_ERROR_MSG macro rather than directly.
 */
220 class MXerrorMsgs {
221  public:
222  /* \brief get singleton pointer to class */
223  static MXerrorMsgs* get();
224 
225  /* \brief add a new error message
  * \param file source file name (MX_ERROR_MSG passes __FILE__)
  * \param line source line number (MX_ERROR_MSG passes __LINE__)
  * \return stream the caller appends the message text to with operator<<
  */
226  std::stringstream& add(const char* file, int line);
227 
228  /* \brief return number of error messages */
229  int size();
230 
231  /* \brief get error message at index
  * \param idx position of the message
  * NOTE(review): behavior for an out-of-range idx is defined in the implementation
  * file, which is not visible here.
  */
232  const std::string* get(int idx);
233 
234  private:
 /* \brief private: only the class itself can construct the instance */
236  MXerrorMsgs() {}
 /* \brief private destructor; presumably releases the stored streams — confirm */
238  ~MXerrorMsgs();
 /* \brief one heap-allocated stream per recorded message */
240  std::vector<std::stringstream*> messages;
241 };
242 
243 // Add a new error message, example: MX_ERROR_MSG << "my error msg";
// Expands to a call to MXerrorMsgs::add with the caller's __FILE__ and __LINE__, so the
// expression is the returned std::stringstream& and can be streamed into directly.
244 #define MX_ERROR_MSG mxnet::ext::MXerrorMsgs::get()->add(__FILE__, __LINE__)
245 
/* \brief Element data types of an MXTensor.
 *
 * The numeric values are explicit and cross the library boundary as plain ints
 * (see the `int dtype` / `int* intypes` parameters elsewhere in this header),
 * so they must not be reordered or renumbered.
 */
249 enum MXDType {
250  kFloat32 = 0,
251  kFloat64 = 1,
252  kFloat16 = 2,
253  kUint8 = 3,
254  kInt32 = 4,
255  kInt8 = 5,
256  kInt64 = 6,
 // presumably a sentinel meaning "no type assigned yet" — confirm
257  kUNSET = 100,
258 };
259 
260 /*
261  * MXTensor storage type.
262  */
264  // dense
266  // row sparse
268  // csr
270 };
271 
/* \brief Device context: a device type string plus a device id.
 *
 * NOTE(review): the accepted dev_type strings and the defaults used by the
 * no-argument constructor and CPU()/GPU() live in the implementation file.
 */
277 struct MXContext {
278  MXContext();
 // both two-argument constructors are explicit: no implicit conversion from a string
279  explicit MXContext(std::string dev_type_, int dev_id_);
280  explicit MXContext(const char* dev_type_, int dev_id_);
 // factory helpers; the overloads taking dev_id select a specific device
281  static MXContext CPU();
282  static MXContext GPU();
283  static MXContext CPU(int dev_id);
284  static MXContext GPU(int dev_id);
285 
 // device type name
286  std::string dev_type;
 // device id
287  int dev_id;
288 };
289 
291  MX_FAIL = 0,
293 };
294 
/* \brief Raw view of a sparse tensor's buffers.
 *
 * For sparse tensors, the data is read from / written to the NDArray through
 * these pointers. The struct only holds pointers and lengths; it does not
 * express ownership of the memory they refer to.
 *
 * Every member has a default initializer, so a default-constructed MXSparse is
 * in a well-defined "empty" state (null pointers, zero lengths) instead of
 * carrying indeterminate lengths and an indeterminate `indices` pointer.
 */
struct MXSparse {
  // Pointer to data.
  void* data{nullptr};
  // Length of (non-zero) data.
  int64_t data_len{0};

  // To store aux data for sparse.
  // For CSR, indices stores the col index of non-zero elements.
  // For row sparse, indices store row index of rows which have non-zero elements.
  int64_t* indices{nullptr};
  int64_t indices_len{0};

  // For CSR, indptr gives the start and end index of data for each row.
  // For row sparse, indptr is not used.
  int64_t* indptr{nullptr};
  int64_t indptr_len{0};

  /* \brief fill in the members from raw buffers
   * \param data_ptr    data buffer
   * \param dims        tensor dimensions (ndims entries)
   * \param ndims       number of entries in dims
   * \param idx         indices buffer
   * \param num_idx     number of indices
   * \param idx_ptr     indptr buffer; defaults to nullptr
   * \param num_idx_ptr number of indptr entries; defaults to 0
   * NOTE(review): the definition lives in the implementation file; the parameter
   * descriptions above are inferred from the names — confirm.
   */
  void set(void *data_ptr, const int64_t* dims, int ndims, void *idx,
           int64_t num_idx, void *idx_ptr = nullptr, int64_t num_idx_ptr = 0);
};
316 
/* \brief Tensor handle passed to extension code.
 *
 * NOTE(review): in this listing the member comments for "type", "context",
 * "DLTensor repr" and "storage type" are not followed by a declaration. The
 * constructor and setTensor() parameters (MXDType, MXContext, MXStorageType)
 * suggest matching members exist in the real header — confirm there.
 */
320 struct MXTensor {
321  MXTensor();
322  MXTensor(const MXTensor& oth);
 // stype defaults to kDefaultStorage
323  MXTensor(void *data_ptr, std::vector<int64_t> shape, MXDType dtype,
324  size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage);
325 
 // populate the tensor from raw pieces; dims points to ndims extents
327  void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
328  size_t vID, MXContext mx_ctx, MXStorageType storage_type);
329 
 // presumably fills in the DLTensor representation from the current fields — confirm
331  void setDLTensor();
332 
 // reinterpret data_ptr as a pointer to data_type; no check is made that
 // data_type matches the tensor's element type
334  template<typename data_type>
335  inline data_type* data() {
336  return reinterpret_cast<data_type*>(data_ptr);
337  }
338 
 // presumably the total number of elements — confirm
340  int64_t size() const;
341 
 // comparison with another tensor; the exact criteria are in the implementation file
343  bool isSame(const MXTensor &oth) const;
344 
345  // For dense, data_ptr points to 1D flattened tensor data
346  // For sparse, data_ptr points to MXSparse
347  void *data_ptr;
348 
349  // shape is in [2,3,4] format to represent high-dim tensor
350  std::vector<int64_t> shape;
351 
352  // type can only be MXDType enum types
354 
355  // version number updated if the tensor has changed since the last use by custom op
356  size_t verID;
357 
358  // context of MXTensor representing which device the tensor data is located
360 
361  // corresponding DLTensor repr of MXTensor
362  // easy way to reuse functions taking DLTensor
364 
365  // storage type
367 };
368 
// Allocation callback: takes an opaque pointer and an int, returns a raw pointer.
// OpResource stores one of these per device next to an opaque allocator pointer;
// presumably the arguments are (allocator handle, size in bytes) — confirm.
370 typedef void* (*xpu_malloc_t)(void*, int);
// Sparse allocation callback; results come back through the three out-pointer
// parameters. OpResource::alloc_sparse takes (sparse, index, indices_len, indptr_len),
// so presumably the ints are (index, indices_len, indptr_len) and the outputs are
// data, indices and indptr — confirm.
372 typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
// NDArray allocation callback used by graph passes: describes the array to create
// (shape, device, dtype, name, arg-vs-aux flag) and returns its buffer through `data`.
374 typedef void (*nd_malloc_t)(const void* _ndarray_alloc, const int64_t* shapes, int num_shapes,
375  const char* dev_str, int dev_id, int dtype, const char* name,
376  int isArg, void** data);
// With NVCC the stream and GPU random-state types are the real CUDA/cuRAND types;
// in any other build both degrade to void* so the header compiles without CUDA.
378 #if defined(__NVCC__)
379  typedef cudaStream_t mx_stream_t;
380  typedef curandStatePhilox4_32_10_t mx_gpu_rand_t;
381 #else
382  typedef void* mx_stream_t;
383  typedef void* mx_gpu_rand_t;
384 #endif
// CPU random state type: the standard 32-bit Mersenne Twister engine.
385 typedef std::mt19937 mx_cpu_rand_t;
386 
388 /* Each thread should generate its own unique random number sequence from a different state. */
// Sizes of the CPU and GPU random-state arrays exposed through OpResource.
389 #define MX_NUM_CPU_RANDOM_STATES 1024
390 #define MX_NUM_GPU_RANDOM_STATES 32768
391 
392 /* \brief Class to help allocate new args/aux params in graph passes */
394  public:
395  PassResource(std::unordered_map<std::string, MXTensor>* new_args,
396  std::unordered_map<std::string, MXTensor>* new_aux,
397  nd_malloc_t nd_malloc, const void* nd_alloc);
398 
399  // allocate new arg param, adds to args map, returns newly allocated tensor
400  MXTensor* alloc_arg(const std::string& name, const std::vector<int64_t>& shapes,
401  const MXContext &ctx, MXDType dtype) const;
402 
403  // allocate new aux param, adds to aux map, returns newly allocated tensor
404  MXTensor* alloc_aux(const std::string& name, const std::vector<int64_t>& shapes,
405  const MXContext &ctx, MXDType dtype) const;
406 
407  private:
408  std::unordered_map<std::string, MXTensor>* new_args_;
409  std::unordered_map<std::string, MXTensor>* new_aux_;
410  nd_malloc_t nd_malloc_;
411  const void* nd_alloc_;
412 };
413 
/* \brief Resources handed to an operator at compute time: CPU/GPU/sparse
 * allocation callbacks, the CUDA stream, and the random number generator states.
 */
417 class OpResource {
418  public:
 // stores the callbacks and opaque handles passed in; the definition is in the
 // implementation file
419  OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
420  xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
421  sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp,
422  void* rng_cpu_states, void* rng_gpu_states);
423 
 // allocate CPU memory; presumably `size` is in bytes and the memory is managed
 // by MXNet rather than by the caller — confirm
425  void* alloc_cpu(int size) const;
426 
 // allocate GPU memory; same hedges as alloc_cpu
428  void* alloc_gpu(int size) const;
429 
 // the stored stream pointer, cast to mx_stream_t (void* in non-NVCC builds)
431  inline mx_stream_t get_cuda_stream() const {
432  return static_cast<mx_stream_t>(cuda_stream);
433  }
434 
 // allocate storage for a sparse tensor; indptr_len defaults to 0
436  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) const;
437 
439  /* Access each state by states[id]; id must be < MX_NUM_CPU_RANDOM_STATES */
440  mx_cpu_rand_t* get_cpu_rand_states() const;
441 
443  /* Access each state by states[id]; id must be < MX_NUM_GPU_RANDOM_STATES */
444  /* Note that if you are using cpu build, it will return a nullptr */
445  inline mx_gpu_rand_t* get_gpu_rand_states() const {
446  return static_cast<mx_gpu_rand_t*>(rand_gpu_states);
447  }
448 
449  private:
 // allocation callbacks for CPU and GPU memory
451  xpu_malloc_t cpu_malloc, gpu_malloc;
 // opaque allocator handles stored alongside the callbacks above
453  void *cpu_alloc, *gpu_alloc;
 // opaque stream pointer returned by get_cuda_stream()
455  void *cuda_stream;
 // allocation callback for sparse tensors
457  sparse_malloc_t sparse_malloc;
 // opaque allocator handle stored alongside sparse_malloc
459  void *sparse_alloc;
 // opaque pointers to the CPU and GPU random-state arrays
461  void *rand_cpu_states, *rand_gpu_states;
462 };
463 
// String constants shared between MXNet and extension libraries.
// NOTE(review): presumably attribute keys (subgraph symbol JSON, per-node dtype,
// per-node shape, extra inputs) — the code that reads and writes them is not in
// this header; confirm before relying on that.
465 #define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
466 
467 #define MX_STR_DTYPE "__ext_dtype__"
468 
469 #define MX_STR_SHAPE "__ext_shape__"
470 
471 #define MX_STR_EXTRA_INPUTS "__ext_extra_inputs__"
472 
473 /* \brief get shape value from list of shapes string
474  *
  * \param shape string holding a bracketed list of shapes, e.g. "[[1],[2,3]]"
  * \param index zero-based position of the shape to extract
  * \return the selected shape, still in string form
  *
475  * Examples:
476  *
477  * getShapeAt("[[1]]", 0) returns "[1]"
478  * getShapeAt("[[1],[2,3]]", 1) returns "[2,3]"
  *
  * NOTE(review): behavior for an out-of-range index or a malformed string is
  * defined in the implementation file, which is not visible here.
479  */
480 std::string getShapeAt(const std::string& shape, unsigned index);
481 
482 /* \brief get dtype value from list of dtypes string
483  *
  * \param dtype string holding a bracketed list of dtype codes, e.g. "[1,2]"
  * \param index zero-based position of the dtype to extract
  * \return the selected dtype, still in string form
  *
484  * Examples:
485  *
486  * getDtypeAt("[1]", 0) returns "1"
487  * getDtypeAt("[1,2]", 1) returns "2"
  *
  * NOTE(review): behavior for an out-of-range index or a malformed string is
  * defined in the implementation file, which is not visible here.
488  */
489 std::string getDtypeAt(const std::string& dtype, unsigned index);
490 
// Kind of value held by a JsonVal: string, number, list or map.
// ERR is the first enumerator (value 0); presumably it marks an invalid or
// unparsed value — confirm.
495 enum JsonType {ERR, STR, NUM, LIST, MAP};
496 
/* \brief Minimal JSON value: a number, string, list or map, plus the parser
 * and serializer for it.
 *
 * Numbers are stored as int, so only integer values are representable.
 * NOTE(review): the constructors take a JsonType, but no member of that type
 * appears in this listing (listing line 534 is missing) — confirm the member
 * exists in the real header.
 */
498 struct JsonVal {
499  JsonVal(); // default constructor
500  // construct a JSON object by type
501  explicit JsonVal(JsonType t);
502  // construct a string JSON object
503  explicit JsonVal(std::string s);
504  // construct a number JSON object
505  explicit JsonVal(int n);
506  // complex constructor
507  JsonVal(JsonType t, int n, std::string s);
 // ordering operator; required because JsonVal is used as the key type of `map` below
508  bool operator<(const JsonVal &o) const;
509 
510  // convert JSON object back to JSON-compatible string
511  std::string dump() const;
512 
513  // convert JSON-compatible string to JSON object
514  static JsonVal parse(const std::string& json);
515 
 // The parse_* helpers and the two-argument parse() take a cursor `idx` into
 // `json` by pointer; presumably they advance it past what they consumed — confirm.
516  // parse a string JSON object
517  static JsonVal parse_string(const std::string& json, unsigned int* idx);
518 
519  // parse a number JSON object
520  static JsonVal parse_num(const std::string& json, unsigned int* idx);
521 
522  // parse a list of JSON objects
523  static JsonVal parse_list(const std::string& json, unsigned int* idx);
524 
525  // parse a map of JSON objects
526  static JsonVal parse_map(const std::string& json, unsigned int* idx);
527 
528  // generic parse function
529  static JsonVal parse(const std::string& json, unsigned int *idx);
530 
531  // debug function to convert data structure to a debugstring
532  std::string toString() const;
533 
 // payload members; presumably only the one matching the value's JsonType is
 // meaningful — confirm
535  int num;
536  std::string str;
537  std::vector<JsonVal> list;
538  std::map<JsonVal, JsonVal> map;
539 };
540 
544 class Node;
545 class Graph;
546 
547 // Representation of an input/output to a node
// Holds a raw, non-owning-looking pointer to the node on the other end of the edge
// plus an index into that node's entries.
548 struct NodeEntry {
549  Node* node; // other node that's producing/consuming inputs/outputs
550  int entry; // entry index from other node (ie. output index from producing node)
551 };
552 
553 // Representation of a node in the graph
// Data members are public and edited directly by graph passes; only the
// PassResource pointer is private and set through _setPassResource().
554 class Node {
555  public:
556  Node();
557 
558  // internally set passResource to enable tensor allocation for graph passes
559  void _setPassResource(PassResource* res_);
560 
561  /* \brief allocate an arg tensor for this node
  * \param shapes dimensions of the tensor to allocate
  * \param ctx    device to allocate on
  * \param dtype  element type
  * NOTE(review): returns void; presumably the result is stored in `tensor` and
  * a PassResource must have been set first — confirm.
  */
562  void alloc_arg(const std::vector<int64_t>& shapes,
563  const MXContext &ctx, MXDType dtype);
564 
565  /* \brief allocate an aux tensor for this node
  * Same parameters and hedges as alloc_arg.
  */
566  void alloc_aux(const std::vector<int64_t>& shapes,
567  const MXContext &ctx, MXDType dtype);
568 
569  std::string op; // operator name (ie. Convolution)
570  std::string name; // unique node name (ie. conv_0 or conv_1)
571  MXTensor* tensor; // tensor data for input nodes
572  std::vector<NodeEntry> inputs; // set of inputs to the node
573  std::vector<NodeEntry> outputs; // set of outputs from the node
574  std::vector<Graph*> subgraphs; // set of subgraphs within this node
575  std::unordered_map<std::string, std::string> attrs; // node attributes
576 
577  private:
 // resource used by alloc_arg/alloc_aux; set via _setPassResource()
578  PassResource* res;
579 };
580 
581 // Representation of the graph
// The node list is private and reached through addNode()/getNode()/size();
// graph inputs, outputs and attributes are public members.
582 class Graph {
583  public:
584  Graph();
585 
586  /* \brief deletes nodes when deleting the graph */
587  ~Graph();
588 
589  /* \brief create a graph object from an unparsed string
  * Returns a raw pointer; presumably the caller owns the new graph — confirm.
  */
590  static Graph* fromString(const std::string& json);
591 
592  /* \brief create a graph object from a parsed JSON object
  * Returns a raw pointer; presumably the caller owns the new graph — confirm.
  */
593  static Graph* fromJson(JsonVal val);
594 
595  /* \brief convert graph object back to JSON object */
596  JsonVal toJson() const;
597 
598  /* \brief convert graph object to JSON string */
599  std::string toString() const;
600 
601  /* \brief visits a node "n"
  * Helper for DFS(); to_visit is the set of nodes still pending and handler is
  * the per-node callback.
  */
602  void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit,
603  std::function<void(Node*)> handler) const;
604 
605  /* \brief post-order DFS graph traversal */
606  void DFS(std::function<void(Node*)> handler) const;
607 
608  /* \brief sort graph nodes in topological order */
609  std::vector<Node*> topological_sort() const;
610 
611  /* \brief print out graph details */
612  void print(int indent = 0) const;
613 
614  /* \brief add a new node to this graph */
615  Node* addNode(const std::string& name, const std::string& op);
616 
617  /* \brief get node at index in graph */
618  Node* getNode(size_t idx);
619 
620  /* \brief get const node at index in const graph */
621  const Node* getNode(size_t idx) const;
622 
623  /* \brief get attribute on graph
  * NOTE(review): behavior for a missing key is defined in the implementation
  * file, which is not visible here.
  */
624  const JsonVal& getAttr(const std::string& key) const;
625 
626  /* \brief get number of nodes in the graph */
627  size_t size() const;
628 
629  // internally set passResource to enable tensor allocation for graph passes
630  void _setPassResource(PassResource* res_);
631 
632  // internally set arg/aux params when available
633  void _setParams(std::unordered_map<std::string, mxnet::ext::MXTensor>* args,
634  std::unordered_map<std::string, mxnet::ext::MXTensor>* aux);
635 
 // graph-level inputs, outputs and attributes
636  std::vector<Node*> inputs;
637  std::vector<NodeEntry> outputs;
638  std::map<std::string, JsonVal> attrs;
639 
640  private:
 // all nodes of the graph
641  std::vector<Node*> nodes;
 // resource used for tensor allocation; set via _setPassResource()
642  PassResource* res;
643 };
644 
645 /* \brief An abstract class for library authors creating custom
646  * partitioners. Optional, can just implement supportedOps instead
647  */
649  public:
650  /* \brief Select a node to include in subgraph, return true to include node
651  * nodeID - index of node in graph
652  */
653  virtual bool Select(int nodeID) = 0;
654  /* \brief Select an input node from current node to include in subgraph
655  * return true to include node
656  * nodeID - index of node in graph
657  * input_nodeID - index of input node in graph
658  */
659  virtual bool SelectInput(int nodeID, int input_nodeID) = 0;
660  /* \brief Select an output node from current node to include in subgraph
661  * return true to include node
662  * nodeID - index of node in graph
663  * output_nodeID - index of output node in graph
664  */
665  virtual bool SelectOutput(int nodeID, int output_nodeID) = 0;
666  /* \brief Review nodes to include in subgraph
667  * return set of candidate nodes to keep in subgraph
668  * candidates - indices of nodes to include in subgraph
669  * keep - indices of nodes to keep in subgraph
670  */
671  virtual void Filter(const std::vector<int>& candidates,
672  std::vector<int>* keep) {
673  keep->insert(keep->end(), candidates.begin(), candidates.end());
674  }
675  /* \brief Reset any selector state, called after growing subgraph, before filter
676  * Called after finished calling SelectInput/SelectOutput and growing subgraph
677  */
678  virtual void Reset() {}
679 };
680 
687  public:
689  virtual ~CustomStatefulOp();
690 
691  template<class A, typename ...Ts>
692  static CustomStatefulOp* create(Ts...args) {
693  CustomStatefulOp* op = new A(args...);
694  op->created = true;
695  return op;
696  }
697 
698  bool wasCreated() { return created; }
699 
700  virtual MXReturnValue Forward(std::vector<MXTensor>* inputs,
701  std::vector<MXTensor>* outputs,
702  const OpResource& op_res) = 0;
703  virtual MXReturnValue Backward(std::vector<MXTensor>* inputs,
704  std::vector<MXTensor>* outputs,
705  const OpResource& op_res) {
706  MX_ERROR_MSG << "Error! Operator does not support backward" << std::endl;
707  return MX_FAIL;
708  }
709 
711 
712  private:
713  bool created;
714 };
715 
// Function-pointer types for the callbacks an extension registers on a CustomOp.
// Each receives the operator's attributes as a string-to-string map and reports
// success or failure through MXReturnValue.

// compute callback (used by CustomOp::setForward and CustomOp::setBackward)
717 typedef MXReturnValue (*fcomp_t)(const std::unordered_map<std::string,
718  std::string>& attributes,
719  std::vector<MXTensor>* inputs,
720  std::vector<MXTensor>* outputs,
721  const OpResource& res);
// attribute parsing callback: reports the input and output counts via the out-pointers
722 typedef MXReturnValue (*parseAttrs_t)(const std::unordered_map<std::string,
723  std::string>& attributes,
724  int* num_inputs, int* num_outputs);
// type inference callback: types travel as plain ints
725 typedef MXReturnValue (*inferType_t)(const std::unordered_map<std::string,
726  std::string>& attributes,
727  std::vector<int>* in_types,
728  std::vector<int>* out_types);
// storage type inference callback: storage types travel as plain ints
729 typedef MXReturnValue (*inferSType_t)(const std::unordered_map<std::string,
730  std::string>& attributes,
731  std::vector<int>* in_storage_types,
732  std::vector<int>* out_storage_types);
// shape inference callback: note shapes are unsigned int here, not the int64_t
// used by MXTensor::shape
733 typedef MXReturnValue (*inferShape_t)(const std::unordered_map<std::string,
734  std::string>& attributes,
735  std::vector<std::vector<unsigned int> >* in_shapes,
736  std::vector<std::vector<unsigned int> >* out_shapes);
// callback filling input_indices; presumably with the indices of inputs the
// operator mutates — confirm
737 typedef MXReturnValue (*mutateInputs_t)(const std::unordered_map<std::string,
738  std::string>& attributes,
739  std::vector<int>* input_indices);
// stateful operator factory: the new instance is returned through the last parameter.
// NOTE(review): in_types is taken by value (`const std::vector<int>`, no `&`), unlike
// in_shapes, so the vector is copied on every call. Switching to a reference would
// change this function-pointer type and break existing extensions, so it is only
// flagged here.
740 typedef MXReturnValue (*createOpState_t)(const std::unordered_map<std::string,
741  std::string>& attributes,
742  const MXContext& ctx,
743  const std::vector<std::vector<unsigned int> >& in_shapes,
744  const std::vector<int> in_types,
745  CustomStatefulOp**);
746 
/* \brief Registration record for a custom operator.
 *
 * Every setter returns CustomOp&, so calls can be chained on the object
 * returned by REGISTER_OP.
 * NOTE(review): several member declarations are missing from this listing
 * (listing lines 777-782, 785 and 793 are absent) — confirm against the real header.
 */
750 class CustomOp {
751  public:
752  explicit CustomOp(const char* op_name);
753 
 // register the forward compute function for the context named by `ctx`
754  CustomOp& setForward(fcomp_t fcomp, const char* ctx);
755 
 // register the backward (gradient) compute function for the context named by `ctx`
756  CustomOp& setBackward(fcomp_t fgrad, const char* ctx);
757 
758  CustomOp& setParseAttrs(parseAttrs_t func);
759 
760  CustomOp& setInferType(inferType_t func);
761 
762  CustomOp& setInferSType(inferSType_t func);
763 
764  CustomOp& setInferShape(inferShape_t func);
765 
766  CustomOp& setMutateInputs(mutateInputs_t func);
767 
 // register the stateful-operator factory for the context named by `ctx`
768  CustomOp& setCreateOpState(createOpState_t func, const char* ctx);
769 
 // presumably sets isSGop — confirm
770  CustomOp& setIsSubgraphOp();
771 
 // presumably copies the private *_ctx_map tables into the public parallel
 // vectors below — confirm
772  void mapToVector();
773 
 // operator name; a raw pointer, so the string it points to must outlive this object
775  const char* name;
776 
 // whether this operator is a subgraph operator
783  bool isSGop;
784 
 // parallel vectors: context names and the function registered for each
786  std::vector<const char*> forward_ctx_cstr, backward_ctx_cstr, create_op_ctx_cstr;
787  std::vector<fcomp_t> forward_fp, backward_fp;
788  std::vector<createOpState_t> create_op_fp;
789 
790  private:
 // presumably reports a second registration for the same context — confirm
791  void raiseDuplicateContextError();
792 
 // NOTE(review): the key type is `const char*`, so the default std::hash and
 // std::equal_to work on the pointer value, not the characters. Two equal context
 // strings at different addresses are distinct keys, which would let duplicate
 // registrations through. Changing the key type needs the implementation file too,
 // so it is only flagged here.
794  std::unordered_map<const char*, fcomp_t> forward_ctx_map, backward_ctx_map;
795  std::unordered_map<const char*, createOpState_t> create_op_ctx_map;
796 };
797 
800  const std::unordered_map<std::string, std::string>& options);
801 
/* \brief Registration record for a custom graph pass.
 *
 * NOTE(review): listing lines 812 and 814-815 are missing, so at least one
 * member (presumably the stored graphPass_t) is not shown — confirm.
 */
805 class CustomPass {
806  public:
807  CustomPass();
808  explicit CustomPass(const char* pass_name);
809 
 // set the function implementing the pass; returns *this-style reference for chaining
810  CustomPass& setBody(graphPass_t fn);
811 
 // pass name; a raw pointer, so the string it points to must outlive this object
813  const char* name;
816 };
817 
// Partitioner callback: given the graph and user options, fills `ids`.
// NOTE(review): presumably one entry per graph node saying whether (or in which
// subgraph) the node is supported — confirm.
819 typedef MXReturnValue (*supportedOps_t)(const mxnet::ext::Graph *graph, std::vector<int>* ids,
820  const std::unordered_map<std::string,
821  std::string>& options);
823  CustomOpSelector** sel_inst,
824  const std::unordered_map<std::string,
825  std::string>& options);
// Partitioner callback: inspects one candidate subgraph, writes the accept/reject
// decision to `accept`, and may return attributes through `attrs`.
826 typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id,
827  bool* accept,
828  const std::unordered_map<std::string,
829  std::string>& options,
830  std::unordered_map<std::string,
831  std::string>* attrs);
832 
837  public:
839 
840  explicit CustomPartitioner(const char* backend_name);
841 
842  CustomPartitioner& addStrategy(const char* prop_name,
843  const char* sg_name);
844 
845  CustomPartitioner& setSupportedOps(const char* prop_name, supportedOps_t fn);
846 
847  CustomPartitioner& setCreateSelector(const char* prop_name, createSelector_t fn);
848 
849  CustomPartitioner& setReviewSubgraph(const char* prop_name, reviewSubgraph_t fn);
850 
851  supportedOps_t getSupportedOps(int stg_id);
852 
853  createSelector_t getCreateSelector(int stg_id);
854 
855  reviewSubgraph_t getReviewSubgraph(int stg_id);
856 
858  const char* name;
859  std::map<std::string, supportedOps_t> supported_map;
860  std::map<std::string, createSelector_t> selector_map;
861  std::map<std::string, reviewSubgraph_t> review_map;
863  std::vector<const char*> strategies;
865  std::vector<const char*> op_names;
866 };
867 
872 template <class T>
873 class Registry {
874  public:
879  static Registry* get() PRIVATE_SYMBOL {
880  static Registry inst;
881  return &inst;
882  }
887  T& add(const char* name) {
888  T *entry = new T(name);
889  entries.push_back(entry);
890  return *entry;
891  }
892  int size() {
893  return entries.size();
894  }
895  T& get(int idx) {
896  return *(entries.at(idx));
897  }
898 
899  private:
901  Registry() {}
903  ~Registry() {}
905  std::vector<T*> entries;
906 };
907 
// Token pasting in two steps so that macro arguments (e.g. __COUNTER__) are
// expanded before being concatenated.
913 #define MX_STR_CONCAT_(__a, __b) __a ## __b
914 #define MX_STR_CONCAT(__a, __b) MX_STR_CONCAT_(__a, __b)
915 
// Stringification in two steps so that macro arguments are expanded before
// being turned into a string literal.
917 #define MX_STRINGIFY(x) #x
918 #define MX_TOSTRING(x) MX_STRINGIFY(x)
919 
// Helpers that build the declaration `mxnet::ext::<Type> MXNet_<Kind>_<Name>`
// for operators, partitioners and passes respectively.
921 #define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _ ## Name
922 #define MX_REGISTER_DEF_(Name) mxnet::ext::CustomOp MX_REGISTER_NAME_(Name)
923 
924 #define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _ ## Name
925 #define MX_REGISTER_PROP_DEF_(Name) mxnet::ext::CustomPartitioner MX_REGISTER_PROP_NAME_(Name)
926 
927 #define MX_REGISTER_PASS_NAME_(Name) MXNet ## _CustomPass ## _ ## Name
928 #define MX_REGISTER_PASS_DEF_(Name) mxnet::ext::CustomPass MX_REGISTER_PASS_NAME_(Name)
929 
// Public registration macros. Each declares a namespace-scope variable whose name
// ends in a unique __COUNTER__ value and initializes it from Registry<T>::add(),
// called with the stringified Name. The macro expands to an expression of the
// registered type, so setters can be chained after it and the statement must be
// terminated by the user with ';'.
931 #define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \
932  mxnet::ext::Registry<mxnet::ext::CustomOp>::get()->add(MX_TOSTRING(Name))
933 
934 #define REGISTER_PARTITIONER(Name) \
935  MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \
936  mxnet::ext::Registry<mxnet::ext::CustomPartitioner>::get()->add(MX_TOSTRING(Name))
937 
938 #define REGISTER_PASS(Name) \
939  MX_STR_CONCAT(MX_REGISTER_PASS_DEF_(Name), __COUNTER__) = \
940  mxnet::ext::Registry<mxnet::ext::CustomPass>::get()->add(MX_TOSTRING(Name))
941 
942 /* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */
943 
// Name string and function-pointer type of the entry point that takes no
// arguments and returns an int; presumably the number of registered operators
// — confirm.
949 #define MXLIB_OPREGSIZE_STR "_opRegSize"
950 typedef int (*opRegSize_t)(void);
951 
952 #define MXLIB_OPREGGET_STR "_opRegGet"
953 typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop,
954  const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp,
955  int* forward_count, const char*** backward_ctx,
956  mxnet::ext::fcomp_t** backward_fp, int* backward_count,
957  const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp,
958  int* create_op_count, mxnet::ext::parseAttrs_t* parse,
961 
// Each pair below is the symbol-name string of an extern "C" entry point and the
// function-pointer type used to call it. Attributes cross the boundary as parallel
// `keys` / `vals` C-string arrays with `num` entries.

// NOTE(review): this pointer type returns int, but the matching declaration
// `MX_VOID_RET _opCallFree(void* ptr)` later in this header returns void. A result
// read through this pointer type would be meaningless — reconcile the two.
962 #define MXLIB_OPCALLFREE_STR "_opCallFree"
963 typedef int (*opCallFree_t)(void* ptr);
964 
// wrapper around a parseAttrs_t callback; counts come back through num_in / num_out
965 #define MXLIB_OPCALLPARSEATTRS_STR "_opCallParseAttrs"
966 typedef int (*opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char* const* keys,
967  const char* const* vals, int num,
968  int* num_in, int* num_out);
969 
// wrapper around an inferShape_t callback; modified input shapes and output shapes
// come back through the triple-pointer parameters with their dimension counts
970 #define MXLIB_OPCALLINFERSHAPE_STR "_opCallInferShape"
971 typedef int (*opCallInferShape_t)(inferShape_t inferShape, const char* const* keys,
972  const char* const* vals, int num,
973  unsigned int** inshapes, int* indims, int num_in,
974  unsigned int*** mod_inshapes, int** mod_indims,
975  unsigned int*** outshapes, int** outdims, int num_out);
976 
// wrapper around an inferType_t callback
977 #define MXLIB_OPCALLINFERTYPE_STR "_opCallInferType"
978 typedef int (*opCallInferType_t)(inferType_t inferType, const char* const* keys,
979  const char* const* vals, int num,
980  int* intypes, int num_in, int* outtypes, int num_out);
981 
// wrapper around an inferSType_t callback. The parameters are named intypes/outtypes
// here but instypes/outstypes in the `_opCallInferSType` declaration; the types are
// identical, and the values are storage types.
982 #define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
983 typedef int (*opCallInferSType_t)(inferSType_t inferSType, const char* const* keys,
984  const char* const* vals, int num,
985  int* intypes, int num_in, int* outtypes, int num_out);
986 
// Entry point that runs an fcomp_t compute callback. Tensors cross the boundary
// flattened into parallel arrays (shapes, dims, data, types, IDs, device type,
// device id) for num_in inputs and num_out outputs, followed by the pieces of an
// OpResource (allocators, stream, random states) and the sparse-tensor metadata
// (storage types, indices, indptr and their lengths).
987 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
988 typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
989  const char* const* vals, int num,
990  const int64_t** inshapes, int* indims,
991  void** indata, int* intypes,
992  size_t* inIDs, const char** indev_type,
993  int* indev_id, int num_in,
994  const int64_t** outshapes, int* outdims,
995  void** outdata, int* outtypes,
996  size_t* outIDs, const char** outdev_type,
997  int* outdev_id, int num_out,
998  xpu_malloc_t cpu_malloc, void* cpu_alloc,
999  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
1000  sparse_malloc_t sparse_malloc, void* sparse_alloc,
1001  int* instypes, int* outstypes,
1002  void** in_indices, void** out_indices,
1003  void** in_indptr, void** out_indptr,
1004  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1005  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1006  void* rng_cpu_states, void* rng_gpu_states);
1007 
// wrapper around a mutateInputs_t callback; the index array and its length come
// back through mutate_indices / indices_size
1008 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
1009 typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
1010  const char* const* vals, int num,
1011  int** mutate_indices, int* indices_size);
1012 
// wrapper around a createOpState_t factory; the created state object comes back
// as an opaque pointer through state_op
1013 #define MXLIB_OPCALLCREATEOPSTATE_STR "_opCallCreateOpState"
1014 typedef int (*opCallCreateOpState_t)(createOpState_t create_op, const char* const* keys,
1015  const char* const* vals, int num, const char* dev_type,
1016  int dev_id, unsigned int** inshapes, int* indims,
1017  int num_in, const int* intypes, void** state_op);
1018 
// takes the opaque state pointer; presumably destroys the object produced by
// _opCallCreateOpState — confirm
1019 #define MXLIB_OPCALLDESTROYOPSTATE_STR "_opCallDestroyOpState"
1020 typedef int (*opCallDestroyOpState_t)(void* state_op);
1021 
// Entry point that runs a stateful operator. `state_op` is the opaque state
// pointer and `is_forward` presumably selects Forward (non-zero) versus Backward
// — confirm. The remaining parameters mirror opCallFComp_t: flattened input/output
// tensors, allocators, stream, sparse metadata and random states.
1022 #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute"
1023 typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
1024  const int64_t** inshapes, int* indims,
1025  void** indata, int* intypes,
1026  size_t* inIDs, const char** indev_type,
1027  int* indev_id, int num_in,
1028  const int64_t** outshapes, int* outdims,
1029  void** outdata, int* outtypes,
1030  size_t* outIDs, const char** outdev_type,
1031  int* outdev_id, int num_out,
1032  xpu_malloc_t cpu_malloc, void* cpu_alloc,
1033  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
1034  sparse_malloc_t sparse_malloc, void* sparse_alloc,
1035  int* instypes, int* outstypes,
1036  void** in_indices, void** out_indices,
1037  void** in_indptr, void** out_indptr,
1038  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1039  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1040  void* rng_cpu_states, void* rng_gpu_states);
1041 
// Partitioner entry points: symbol-name strings and the function-pointer types
// used to call them. Graphs cross the boundary as JSON text and selector objects
// as opaque `void*` handles.

// takes no arguments, returns an int; presumably the number of registered
// partitioners — confirm
1042 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
1043 typedef int (*partRegSize_t)(void);
1044 
// returns the partitioner's name through `name`; presumably the int result is the
// number of strategies it has — confirm
1045 #define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount"
1046 typedef int (*partRegGetCount_t)(int idx, const char** name);
1047 
// fetches one strategy of one partitioner: its name, its three callbacks and the
// associated operator name, all through out-pointers
1048 #define MXLIB_PARTREGGET_STR "_partRegGet"
1049 typedef void (*partRegGet_t)(int part_idx, int stg_idx, const char** strategy,
1050  supportedOps_t* supportedOps, createSelector_t* createSelector,
1051  reviewSubgraph_t* reviewSubgraph, const char** op_name);
1052 
// wrapper around a supportedOps_t callback; `ids` has num_ids entries
1053 #define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps"
1054 typedef int (*partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json,
1055  int num_ids, int *ids, const char* const* opt_keys,
1056  const char* const* opt_vals, int num_opts);
1057 
// wrapper around a createSelector_t callback; the selector comes back as an opaque
// pointer through `selector`
1058 #define MXLIB_PARTCALLCREATESELECTOR_STR "_partCallCreateSelector"
1059 typedef int (*partCallCreateSelector_t)(createSelector_t createSelector, const char *json,
1060  void** selector, const char* const* opt_keys,
1061  const char* const* opt_vals, int num_opts);
1062 
// The five entry points below take the opaque selector handle and mirror the
// CustomOpSelector methods Select, SelectInput, SelectOutput, Filter and Reset;
// boolean results are returned as int through `selected`.
1063 #define MXLIB_PARTCALLSELECT_STR "_partCallSelect"
1064 typedef void (*partCallSelect_t)(void* sel_inst, int nodeID, int* selected);
1065 
1066 #define MXLIB_PARTCALLSELECTINPUT_STR "_partCallSelectInput"
1067 typedef void (*partCallSelectInput_t)(void* sel_inst, int nodeID, int input_nodeID,
1068  int* selected);
1069 
1070 #define MXLIB_PARTCALLSELECTOUTPUT_STR "_partCallSelectOutput"
1071 typedef void (*partCallSelectOutput_t)(void* sel_inst, int nodeID, int output_nodeID,
1072  int* selected);
1073 
// the kept node list and its length come back through keep / num_keep
1074 #define MXLIB_PARTCALLFILTER_STR "_partCallFilter"
1075 typedef void (*partCallFilter_t)(void* sel_inst, int* candidates, int num_candidates,
1076  int** keep, int* num_keep);
1077 
1078 #define MXLIB_PARTCALLRESET_STR "_partCallReset"
1079 typedef void (*partCallReset_t)(void* sel_inst);
1080 
// Wrapper around a reviewSubgraph_t callback. The subgraph arrives as JSON text,
// the decision is returned as int through `accept`, and attributes come back as
// parallel key/value string arrays (attr_keys, attr_vals, num_attrs). The arg_* and
// aux_* groups are flattened tensor lists: names, data pointers, shapes, dimension
// counts, types, version IDs and device information, num_args / num_aux entries each.
1081 #define MXLIB_PARTCALLREVIEWSUBGRAPH_STR "_partCallReviewSubgraph"
1082 typedef int (*partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json,
1083  int subgraph_id, int *accept, const char* const* opt_keys,
1084  const char* const* opt_vals, int num_opts,
1085  char*** attr_keys, char*** attr_vals, int *num_attrs,
1086  const char* const* arg_names, int num_args,
1087  void* const* arg_data, const int64_t* const* arg_shapes,
1088  const int* arg_dims, const int* arg_types,
1089  const size_t* arg_IDs, const char* const* arg_dev_type,
1090  const int* arg_dev_id,
1091  const char* const* aux_names, int num_aux,
1092  void* const* aux_data, const int64_t* const* aux_shapes,
1093  const int* aux_dims, const int* aux_types,
1094  const size_t* aux_IDs, const char* const* aux_dev_type,
1095  const int* aux_dev_id);
1096 
// Graph-pass and library-level entry points: symbol-name strings and the
// function-pointer types used to call them.

// takes no arguments, returns an int; presumably the number of registered passes
// — confirm
1097 #define MXLIB_PASSREGSIZE_STR "_passRegSize"
1098 typedef int (*passRegSize_t)(void);
1099 
// fetches the callback and name of the pass at pass_idx through out-pointers
1100 #define MXLIB_PASSREGGET_STR "_passRegGet"
1101 typedef void (*passRegGet_t)(int pass_idx, graphPass_t* graphPass, const char** pass_name);
1102 
// Wrapper around a graphPass_t callback. The input graph arrives as JSON text and
// the result is returned as a C string through out_graph. The arg_* / aux_* groups
// are flattened tensor lists, and nd_malloc / nd_alloc are the allocation callback
// and its opaque handle (the same pair PassResource is constructed with).
1103 #define MXLIB_PASSCALLGRAPHPASS_STR "_passCallGraphPass"
1104 typedef int (*passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph,
1105  char** out_graph, const char* const* opt_keys,
1106  const char* const* opt_vals, int num_opts,
1107  const char* pass_name, const char* const* arg_names,
1108  int num_args, void* const* arg_data,
1109  const int64_t* const* arg_shapes, const int* arg_dims,
1110  const int* arg_types, const size_t* arg_IDs,
1111  const char* const* arg_dev_type, const int* arg_dev_id,
1112  const char* const* aux_names, int num_aux,
1113  void* const* aux_data, const int64_t* const* aux_shapes,
1114  const int* aux_dims, const int* aux_types,
1115  const size_t* aux_IDs, const char* const* aux_dev_type,
1116  const int* aux_dev_id, nd_malloc_t nd_malloc,
1117  const void* nd_alloc);
1118 
// Library initialization hook: receives a version number and returns an int status.
// Unlike the other symbol names, this one has no leading underscore.
1119 #define MXLIB_INITIALIZE_STR "initialize"
1120 typedef int (*initialize_t)(int version);
1121 
// takes no arguments, returns an int; presumably the MX_LIBRARY_VERSION the library
// was built against — confirm
1122 #define MXLIB_OPVERSION_STR "_opVersion"
1123 typedef int (*opVersion_t)();
1124 
// takes no arguments, returns an int; presumably the number of stored error
// messages (see MXerrorMsgs::size) — confirm
1125 #define MXLIB_MSGSIZE_STR "_msgSize"
1126 typedef int (*msgSize_t)(void);
1127 
// returns the message at idx through `msg` (see MXerrorMsgs::get)
1128 #define MXLIB_MSGGET_STR "_msgGet"
1129 typedef int (*msgGet_t)(int idx, const char** msg);
1130 
1133  public:
1136  : instance(inst), destroy_(destroy) {}
1137  CustomStatefulOp* get_instance() { return instance; }
1138  private:
1139  CustomStatefulOp* instance;
1140  opCallDestroyOpState_t destroy_;
1141 };
1142 
// Return-type prefixes for the extern "C" entry points. On Windows they add
// dllexport and the __cdecl calling convention; on every other platform they are
// just `int` and `void`.
1143 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
1144 #define MX_INT_RET __declspec(dllexport) int __cdecl
1145 #define MX_VOID_RET __declspec(dllexport) void __cdecl
1146 #else
1147 #define MX_INT_RET int
1148 #define MX_VOID_RET void
1149 #endif
1150 
1151 } // namespace ext
1152 } // namespace mxnet
1153 
1154 extern "C" {
1157 
1160 
1162  MX_VOID_RET _opRegGet(int idx, const char** name, int *isSGop,
1163  const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp,
1164  int* forward_count, const char*** backward_ctx,
1165  mxnet::ext::fcomp_t** backward_fp, int* backward_count,
1166  const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp,
1167  int* create_op_count, mxnet::ext::parseAttrs_t* parse,
1170 
/*! \brief calls free from the external library for library allocated arrays.
 * NOTE(review): declared void here, while `opCallFree_t` is documented as
 * `int(*)(void *ptr)` -- confirm which return type is intended. */
1172  MX_VOID_RET _opCallFree(void* ptr);
1173 
/*! \brief returns status of calling parse attributes function for operator from library.
 * Attributes are passed as parallel `keys`/`vals` arrays of length `num`. */
1175  MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char* const* keys,
1176  const char* const* vals, int num,
1177  int* num_in, int* num_out);
1178 
/*! \brief returns status of calling inferShape function for operator from library.
 * Attributes are passed as parallel `keys`/`vals` arrays of length `num`. */
1180  MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char* const* keys,
1181  const char* const* vals, int num,
1182  unsigned int** inshapes, int* indims, int num_in,
1183  unsigned int*** mod_inshapes, int** mod_indims,
1184  unsigned int*** outshapes, int** outdims, int num_out);
1185 
/*! \brief returns status of calling inferType function for operator from library */
1187  MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char* const* keys,
1188  const char* const* vals, int num,
1189  int* intypes, int num_in, int* outtypes, int num_out);
1190 
/*! \brief returns status of calling inferSType function for operator from library */
1192  MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char* const* keys,
1193  const char* const* vals, int num,
1194  int* instypes, int num_in, int* outstypes, int num_out);
1195 
/*! \brief returns status of calling Forward/Backward function for operator from library.
 * Inputs and outputs are each described by parallel per-tensor arrays
 * (shapes, dims, data, types, IDs, device type/id, storage types and sparse
 * index buffers); the allocator callbacks, stream and RNG state pointers are
 * passed through as opaque handles. */
1197  MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys,
1198  const char* const* vals,
1199  int num, const int64_t** inshapes, int* indims, void** indata,
1200  int* intypes, size_t* inIDs, const char** indev_type, int* indev_id,
1201  int num_in, const int64_t** outshapes, int* outdims, void** outdata,
1202  int* outtypes, size_t* outIDs, const char** outdev_type,
1203  int* outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc,
1204  void* cpu_alloc,
1205  mxnet::ext::xpu_malloc_t gpu_malloc, void* gpu_alloc,
1206  void* cuda_stream,
1207  mxnet::ext::sparse_malloc_t sparse_malloc, void* sparse_alloc,
1208  int* instypes, int* outstypes, void** in_indices, void** out_indices,
1209  void** in_indptr, void** out_indptr,
1210  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
1211  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes,
1212  void* rng_cpu_states, void* rng_gpu_states);
1213 
/*! \brief returns status of calling mutateInputs function for operator from library */
1215  MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* const* keys,
1216  const char* const* vals, int num,
1217  int** mutate_indices, int* indices_size);
1218 
/*! \brief returns status of calling createStatefulOp function for operator from library */
1220  MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys,
1221  const char* const* vals, int num, const char* dev_type,
1222  int dev_id, unsigned int** inshapes, int* indims,
1223  int num_in, const int* intypes, void** state_op);
1224 
/*! \brief deletes the StatefulOp instance for operator from library.
 * NOTE(review): declared void here, while `opCallDestroyOpState_t` is
 * documented as `int(*)(void *state_op)` -- confirm which return type is
 * intended. */
1226  MX_VOID_RET _opCallDestroyOpState(void* state_op);
1227 
/*! \brief returns status of calling Stateful Forward/Backward for operator from library.
 * `is_forward` presumably selects between the Forward and Backward call --
 * TODO confirm. The remaining parameters follow the same layout as
 * `_opCallFCompute`, with `state_op` in place of the compute function and
 * attribute arrays. */
1229  MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes,
1230  int* indims, void** indata, int* intypes, size_t* inIDs,
1231  const char** indev_type, int* indev_id, int num_in,
1232  const int64_t** outshapes, int* outdims, void** outdata,
1233  int* outtypes, size_t* outIDs, const char** outdev_type,
1234  int* outdev_id, int num_out,
1235  mxnet::ext::xpu_malloc_t cpu_malloc,
1236  void* cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc,
1237  void* gpu_alloc,
1238  void* stream, mxnet::ext::sparse_malloc_t sparse_malloc,
1239  void* sparse_alloc, int* instypes, int* outstypes,
1240  void** in_indices, void** out_indices, void** in_indptr,
1241  void** out_indptr, int64_t* in_indices_shapes,
1242  int64_t* out_indices_shapes, int64_t* in_indptr_shapes,
1243  int64_t* out_indptr_shapes,
1244  void* rng_cpu_states, void* rng_gpu_states);
1245 
1248 
1249  /* returns the number of strategies registered for the partitioner
1250  * at the specified index (presumably also reports its name via `name` -- TODO confirm) */
1251  MX_INT_RET _partRegGetCount(int idx, const char** name);
1252 
/*! \brief returns partitioner registration at specified index.
 * `part_idx` and `stg_idx` are separate indices (presumably partitioner and
 * strategy -- TODO confirm); results are reported through the pointer
 * parameters. */
1254  MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char** strategy,
1255  mxnet::ext::supportedOps_t* supportedOps,
1256  mxnet::ext::createSelector_t* createSelector,
1257  mxnet::ext::reviewSubgraph_t* reviewSubgraph, const char** op_name);
1258 
/*! \brief returns status of calling supported ops function from library.
 * Options are passed as parallel `opt_keys`/`opt_vals` arrays of length
 * `num_opts`. */
1260  MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json,
1261  int num_ids, int *ids, const char* const* opt_keys,
1262  const char* const* opt_vals, int num_opts);
1263 
/*! \brief returns status of calling create selector function from library */
1265  MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json,
1266  void** selector, const char* const* opt_keys,
1267  const char* const* opt_vals, int num_opts);
1268 
/*! \brief calls select function from library. Declared void, so no status is
 * returned; the result is presumably written to `selected` -- TODO confirm. */
1270  MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected);
1271 
/*! \brief calls select input function from library. Declared void, so no
 * status is returned; the result is presumably written to `selected` --
 * TODO confirm. */
1273  MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID,
1274  int input_nodeID, int* selected);
1275 
/*! \brief calls select output function from library. Declared void, so no
 * status is returned; the result is presumably written to `selected` --
 * TODO confirm. */
1277  MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID,
1278  int output_nodeID, int* selected);
1279 
/*! \brief calls filter function from library. Declared void, so no status is
 * returned; the kept candidates are presumably reported through `keep` and
 * `num_keep` -- TODO confirm. */
1281  MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates,
1282  int** keep, int* num_keep);
1283 
/*! \brief calls reset selector function from library (declared void; no
 * status is returned) */
1285  MX_VOID_RET _partCallReset(void* sel_inst);
1286 
/*! \brief returns status of calling review subgraph function from library.
 * Options are passed as parallel `opt_keys`/`opt_vals` arrays; the args and
 * aux params are each described by parallel per-tensor arrays (names, data,
 * shapes, dims, types, IDs, device type/id). */
1288  MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json,
1289  int subgraph_id, int *accept, const char* const* opt_keys,
1290  const char* const* opt_vals, int num_opts,
1291  char*** attr_keys, char*** attr_vals, int *num_attrs,
1292  const char* const* arg_names, int num_args,
1293  void* const* arg_data, const int64_t* const* arg_shapes,
1294  const int* arg_dims, const int* arg_types,
1295  const size_t* arg_IDs, const char* const* arg_dev_type,
1296  const int* arg_dev_id,
1297  const char* const* aux_names, int num_aux,
1298  void* const* aux_data, const int64_t* const* aux_shapes,
1299  const int* aux_dims, const int* aux_types,
1300  const size_t* aux_IDs, const char* const* aux_dev_type,
1301  const int* aux_dev_id);
1302 
1305 
/*! \brief returns pass registration at specified index (reported through the
 * `graphPass` and `pass_name` pointer parameters) */
1307  MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass,
1308  const char** pass_name);
1309 
/*! \brief returns status of calling graph pass function from library.
 * The parameter list matches the `passCallGraphPass_t` function-pointer type,
 * except that the input graph parameter is named `json` here. */
1311  MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json,
1312  char** out_graph, const char* const* opt_keys,
1313  const char* const* opt_vals, int num_opts,
1314  const char* pass_name, const char* const* arg_names, int num_args,
1315  void* const* arg_data, const int64_t* const* arg_shapes,
1316  const int* arg_dims, const int* arg_types,
1317  const size_t* arg_IDs, const char* const* arg_dev_type,
1318  const int* arg_dev_id, const char* const* aux_names, int num_aux,
1319  void* const* aux_data, const int64_t* const* aux_shapes,
1320  const int* aux_dims, const int* aux_types,
1321  const size_t* aux_IDs, const char* const* aux_dev_type,
1322  const int* aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc,
1323  const void* nd_alloc);
1324 
1332 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
1333  __declspec(dllexport) mxnet::ext::MXReturnValue __cdecl
1334 #else
1336 #endif
1337  initialize(int version);
1338 
/*! \brief presumably returns the number of messages retrievable through
 * `_msgGet` -- TODO confirm against the implementation */
1339  MX_INT_RET _msgSize();
1340 
/*! \brief retrieves the message at the specified index through `msg`.
 * NOTE(review): declared void here, but the `msgGet_t` function-pointer type
 * returns `int` -- confirm which return type is intended and align the two. */
1342  MX_VOID_RET _msgGet(int idx, const char** msg);
1343 } // extern "C"
1344 
1345 #endif // MXNET_LIB_API_H_
An abstract class for subgraph property.
Definition: lib_api.h:836
int(* opVersion_t)()
Definition: lib_api.h:1123
int(* partCallSupportedOps_t)(supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1054
const char * name
partitioner name
Definition: lib_api.h:858
MXReturnValue(* inferShape_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< std::vector< unsigned int > > *in_shapes, std::vector< std::vector< unsigned int > > *out_shapes)
Definition: lib_api.h:733
DLDeviceType
The device type in DLContext.
Definition: lib_api.h:101
std::string getShapeAt(const std::string &shape, unsigned index)
int(* opCallParseAttrs_t)(parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
Definition: lib_api.h:966
MX_VOID_RET _opRegGet(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
returns operator registration at specified index
virtual MXReturnValue Backward(std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &op_res)
Definition: lib_api.h:703
int(* msgGet_t)(int idx, const char **msg)
Definition: lib_api.h:1129
const char * name
operator name
Definition: lib_api.h:775
Definition: lib_api.h:257
MXTensor * tensor
Definition: lib_api.h:571
virtual void Filter(const std::vector< int > &candidates, std::vector< int > *keep)
Definition: lib_api.h:671
CustomStatefulOpWrapper(CustomStatefulOp *inst, opCallDestroyOpState_t destroy)
Definition: lib_api.h:1135
provides resource APIs (memory allocation mechanism) to Forward/Backward functions
Definition: lib_api.h:417
Definition: lib_api.h:255
void(* partCallReset_t)(void *sel_inst)
Definition: lib_api.h:1079
std::vector< const char * > strategies
strategy names
Definition: lib_api.h:863
OpenCL devices.
Definition: lib_api.h:112
std::vector< NodeEntry > outputs
Definition: lib_api.h:573
DLDataTypeCode
The type code options DLDataType.
Definition: lib_api.h:142
MX_VOID_RET _partCallSelectOutput(void *sel_inst, int nodeID, int output_nodeID, int *selected)
returns status of calling select output function from library
StatefulOp wrapper class to pass to backend OpState.
Definition: lib_api.h:1132
MX_INT_RET _opRegSize()
returns number of ops registered in this library
T & add(const char *name)
add a new entry
Definition: lib_api.h:887
Metal for Apple GPU.
Definition: lib_api.h:116
Definition: lib_api.h:296
int(* passCallGraphPass_t)(graphPass_t graphPass, const char *in_graph, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, nd_malloc_t nd_malloc, const void *nd_alloc)
Definition: lib_api.h:1104
namespace of mxnet
Definition: api_registry.h:33
Definition: lib_api.h:145
void(* partCallSelectInput_t)(void *sel_inst, int nodeID, int input_nodeID, int *selected)
Definition: lib_api.h:1067
#define MX_VOID_RET
Definition: lib_api.h:1148
ROCm GPUs for AMD GPUs.
Definition: lib_api.h:120
A Device context for Tensor and operator.
Definition: dlpack.h:69
std::vector< Graph * > subgraphs
Definition: lib_api.h:574
CUDA GPU device.
Definition: lib_api.h:105
Node * node
Definition: lib_api.h:549
bool ignore_warn
Definition: lib_api.h:710
void(* partRegGet_t)(int part_idx, int stg_idx, const char **strategy, supportedOps_t *supportedOps, createSelector_t *createSelector, reviewSubgraph_t *reviewSubgraph, const char **op_name)
Definition: lib_api.h:1049
MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
returns status of calling mutateInputs function for operator from library
int64_t * indices
Definition: lib_api.h:305
int64_t data_len
Definition: lib_api.h:300
MXStorageType
Definition: lib_api.h:263
Reserved extension device type, used to quickly test an extension device. The semantics can differ depen...
Definition: lib_api.h:126
Pinned CUDA GPU device by cudaMallocHost.
Definition: lib_api.h:110
inferSType_t infer_storage_type
Definition: lib_api.h:780
MXContext ctx
Definition: lib_api.h:359
#define PRIVATE_SYMBOL
For loading multiple custom op libraries in Linux, exporting same symbol multiple times may lead to u...
Definition: lib_api.h:65
MX_INT_RET _partRegSize()
returns number of partitioners registered in this library
Definition: lib_api.h:495
MXReturnValue(* graphPass_t)(mxnet::ext::Graph *graph, const std::unordered_map< std::string, std::string > &options)
Custom Pass Create function template.
Definition: lib_api.h:799
MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling create selector function from library
MX_VOID_RET _partCallSelect(void *sel_inst, int nodeID, int *selected)
returns status of calling select function from library
MX_VOID_RET _opCallDestroyOpState(void *state_op)
deletes the StatefulOp instance for operator from library
Definition: lib_api.h:254
int num
Definition: lib_api.h:535
Tensor data structure used by custom operator.
Definition: lib_api.h:320
MX_VOID_RET _partCallFilter(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
returns status of calling filter function from library
std::unordered_map< std::string, std::string > attrs
Definition: lib_api.h:575
void *(* xpu_malloc_t)(void *, int)
resource malloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:370
int(* opCallInferShape_t)(inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
Definition: lib_api.h:971
MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *instypes, int num_in, int *outstypes, int num_out)
returns status of calling inferSType function for operator from library
MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json, int num_ids, int *ids, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
returns status of calling supported ops function from library
MX_VOID_RET _partCallSelectInput(void *sel_inst, int nodeID, int input_nodeID, int *selected)
returns status of calling select input function from library
MX_INT_RET _msgSize()
MXReturnValue(* createOpState_t)(const std::unordered_map< std::string, std::string > &attributes, const MXContext &ctx, const std::vector< std::vector< unsigned int > > &in_shapes, const std::vector< int > in_types, CustomStatefulOp **)
Definition: lib_api.h:740
MXReturnValue
Definition: lib_api.h:290
void(* partCallSelectOutput_t)(void *sel_inst, int nodeID, int output_nodeID, int *selected)
Definition: lib_api.h:1071
std::string name
Definition: lib_api.h:570
MX_INT_RET _opCallFStatefulCompute(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Stateful Forward/Backward for operator from library
CustomStatefulOp * get_instance()
Definition: lib_api.h:1137
std::map< std::string, createSelector_t > selector_map
Definition: lib_api.h:860
std::vector< NodeEntry > inputs
Definition: lib_api.h:572
Definition: lib_api.h:253
graphPass_t pass
pass function
Definition: lib_api.h:815
Definition: lib_api.h:256
std::string str
Definition: lib_api.h:536
MXReturnValue(* createSelector_t)(const mxnet::ext::Graph *graph, CustomOpSelector **sel_inst, const std::unordered_map< std::string, std::string > &options)
Definition: lib_api.h:822
MXDType dtype
Definition: lib_api.h:353
int(* opCallMutateInputs_t)(mutateInputs_t mutate, const char *const *keys, const char *const *vals, int num, int **mutate_indices, int *indices_size)
Definition: lib_api.h:1009
Definition: lib_api.h:143
inferType_t infer_type
Definition: lib_api.h:779
std::map< std::string, JsonVal > attrs
Definition: lib_api.h:638
void(* partCallSelect_t)(void *sel_inst, int nodeID, int *selected)
Definition: lib_api.h:1064
int entry
Definition: lib_api.h:550
Definition: lib_api.h:250
Definition: lib_api.h:220
Class to hold custom operator registration.
Definition: lib_api.h:750
MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
returns status of calling inferType function for operator from library
Definition: lib_api.h:582
int(* opRegGet_t)(int idx, const char **name, int *isSGop, const char ***forward_ctx, mxnet::ext::fcomp_t **forward_fp, int *forward_count, const char ***backward_ctx, mxnet::ext::fcomp_t **backward_fp, int *backward_count, const char ***create_op_ctx, mxnet::ext::createOpState_t **create_op_fp, int *create_op_count, mxnet::ext::parseAttrs_t *parse, mxnet::ext::inferType_t *type, mxnet::ext::inferSType_t *stype, mxnet::ext::inferShape_t *shape, mxnet::ext::mutateInputs_t *mutate)
Definition: lib_api.h:953
Vulkan buffer for next generation graphics.
Definition: lib_api.h:114
MX_INT_RET _opVersion()
returns MXNet library version
inferShape_t infer_shape
Definition: lib_api.h:781
#define MX_ERROR_MSG
Definition: lib_api.h:244
int64_t indices_len
Definition: lib_api.h:306
MXReturnValue(* inferSType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_storage_types, std::vector< int > *out_storage_types)
Definition: lib_api.h:729
std::vector< const char * > op_names
subgraph operator name
Definition: lib_api.h:865
MXReturnValue(* reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, bool *accept, const std::unordered_map< std::string, std::string > &options, std::unordered_map< std::string, std::string > *attrs)
Definition: lib_api.h:826
DLDeviceType
The device type in DLContext.
Definition: dlpack.h:38
MX_INT_RET _partRegGetCount(int idx, const char **name)
void * mx_gpu_rand_t
Definition: lib_api.h:383
Verilog simulator buffer.
Definition: lib_api.h:118
mx_stream_t get_cuda_stream() const
return the cuda stream object with correct type
Definition: lib_api.h:431
std::vector< createOpState_t > create_op_fp
Definition: lib_api.h:788
MXReturnValue(* fcomp_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< MXTensor > *inputs, std::vector< MXTensor > *outputs, const OpResource &res)
Custom Operator function templates.
Definition: lib_api.h:717
mxnet::ext::MXReturnValue initialize(int version)
Checks if the MXNet version is supported by the library. If supported, initializes the library...
JsonType
Json utility to parse serialized subgraph symbol.
Definition: lib_api.h:495
MX_VOID_RET _opCallFree(void *ptr)
calls free from the external library for library allocated arrays
MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
returns status of calling review subgraph function from library
int(* opCallInferSType_t)(inferSType_t inferSType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:983
MX_INT_RET _passRegSize()
returns number of graph passes registered in this library
Definition: lib_api.h:144
int(* passRegSize_t)(void)
Definition: lib_api.h:1098
definition of JSON objects
Definition: lib_api.h:498
Definition: lib_api.h:548
MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json, char **out_graph, const char *const *opt_keys, const char *const *opt_vals, int num_opts, const char *pass_name, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, const void *nd_alloc)
returns status of calling graph pass function from library
int(* partCallReviewSubgraph_t)(reviewSubgraph_t reviewSubgraph, const char *json, int subgraph_id, int *accept, const char *const *opt_keys, const char *const *opt_vals, int num_opts, char ***attr_keys, char ***attr_vals, int *num_attrs, const char *const *arg_names, int num_args, void *const *arg_data, const int64_t *const *arg_shapes, const int *arg_dims, const int *arg_types, const size_t *arg_IDs, const char *const *arg_dev_type, const int *arg_dev_id, const char *const *aux_names, int num_aux, void *const *aux_data, const int64_t *const *aux_shapes, const int *aux_dims, const int *aux_types, const size_t *aux_IDs, const char *const *aux_dev_type, const int *aux_dev_id)
Definition: lib_api.h:1082
std::string getDtypeAt(const std::string &dtype, unsigned index)
bool isSGop
Definition: lib_api.h:783
MXReturnValue(* inferType_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *in_types, std::vector< int > *out_types)
Definition: lib_api.h:725
int(* opCallFree_t)(void *ptr)
Definition: lib_api.h:963
int(* opCallFStatefulComp_t)(int is_forward, void *state_op, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:1023
size_t verID
Definition: lib_api.h:356
std::vector< JsonVal > list
Definition: lib_api.h:537
MXReturnValue(* supportedOps_t)(const mxnet::ext::Graph *graph, std::vector< int > *ids, const std::unordered_map< std::string, std::string > &options)
Custom Subgraph Create function template.
Definition: lib_api.h:819
#define MX_INT_RET
Definition: lib_api.h:1147
const char * name
pass name
Definition: lib_api.h:813
parseAttrs_t parse_attrs
operator functions
Definition: lib_api.h:778
Definition: lib_api.h:393
data_type * data()
helper function to cast data pointer
Definition: lib_api.h:335
MXReturnValue(* parseAttrs_t)(const std::unordered_map< std::string, std::string > &attributes, int *num_inputs, int *num_outputs)
Definition: lib_api.h:722
mutateInputs_t mutate_inputs
Definition: lib_api.h:782
bool wasCreated()
Definition: lib_api.h:698
Context info passed from MXNet OpContext. dev_type is the string repr of a supported context, currently only "cpu" and "gpu"; dev_id is the device index where the tensor is located.
Definition: lib_api.h:277
std::mt19937 mx_cpu_rand_t
Definition: lib_api.h:385
MX_VOID_RET _msgGet(int idx, const char **msg)
returns the message at the specified index
MX_VOID_RET _partCallReset(void *sel_inst)
calls the reset selector function from library
Definition: lib_api.h:291
static CustomStatefulOp * create(Ts...args)
Definition: lib_api.h:692
void * mx_stream_t
GPU stream pointer, is void* when not compiled with CUDA.
Definition: lib_api.h:382
An abstract class for library authors creating stateful op custom library should override Forward and...
Definition: lib_api.h:686
int(* partCallCreateSelector_t)(createSelector_t createSelector, const char *json, void **selector, const char *const *opt_keys, const char *const *opt_vals, int num_opts)
Definition: lib_api.h:1059
std::map< std::string, supportedOps_t > supported_map
Definition: lib_api.h:859
int size()
Definition: lib_api.h:892
Definition: lib_api.h:292
std::vector< const char * > forward_ctx_cstr
vector repr of ctx map to be easily loaded from c_api
Definition: lib_api.h:786
CPU device.
Definition: lib_api.h:103
int(* opCallDestroyOpState_t)(void *state_op)
Definition: lib_api.h:1020
int(* partRegSize_t)(void)
Definition: lib_api.h:1043
Definition: lib_api.h:554
void(* passRegGet_t)(int pass_idx, graphPass_t *graphPass, const char **pass_name)
Definition: lib_api.h:1101
Definition: lib_api.h:495
void(* sparse_malloc_t)(void *, int, int, int, void **, int64_t **, int64_t **)
sparse alloc function to allocate memory inside Forward/Backward functions
Definition: lib_api.h:372
MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t *graphPass, const char **pass_name)
returns pass registration at specified index
void(* partCallFilter_t)(void *sel_inst, int *candidates, int num_candidates, int **keep, int *num_keep)
Definition: lib_api.h:1075
std::vector< NodeEntry > outputs
Definition: lib_api.h:637
int64_t indptr_len
Definition: lib_api.h:311
int dev_id
Definition: lib_api.h:287
std::vector< Node * > inputs
Definition: lib_api.h:636
DLTensor dltensor
Definition: lib_api.h:363
MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
returns status of calling createStatefulOp function for operator from library
int(* opCallInferType_t)(inferType_t inferType, const char *const *keys, const char *const *vals, int num, int *intypes, int num_in, int *outtypes, int num_out)
Definition: lib_api.h:978
std::string dev_type
Definition: lib_api.h:286
int(* opCallCreateOpState_t)(createOpState_t create_op, const char *const *keys, const char *const *vals, int num, const char *dev_type, int dev_id, unsigned int **inshapes, int *indims, int num_in, const int *intypes, void **state_op)
Definition: lib_api.h:1014
MXDType
Tensor data type, consistent with mshadow data type.
Definition: lib_api.h:249
int(* partRegGetCount_t)(int idx, const char **name)
Definition: lib_api.h:1046
std::vector< fcomp_t > forward_fp
Definition: lib_api.h:787
Definition: lib_api.h:495
MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char *const *keys, const char *const *vals, int num, unsigned int **inshapes, int *indims, int num_in, unsigned int ***mod_inshapes, int **mod_indims, unsigned int ***outshapes, int **outdims, int num_out)
returns status of calling inferShape function for operator from library
int(* initialize_t)(int version)
Definition: lib_api.h:1120
std::map< JsonVal, JsonVal > map
Definition: lib_api.h:538
Definition: lib_api.h:269
int(* opRegSize_t)(void)
Definition: lib_api.h:950
An abstract class for graph passes.
Definition: lib_api.h:805
std::vector< int64_t > shape
Definition: lib_api.h:350
Registry class to registers things (ops, properties) Singleton class.
Definition: lib_api.h:873
JsonType type
Definition: lib_api.h:534
Definition: lib_api.h:267
The data type the tensor can hold.
Definition: dlpack.h:94
void(* nd_malloc_t)(const void *_ndarray_alloc, const int64_t *shapes, int num_shapes, const char *dev_str, int dev_id, int dtype, const char *name, int isArg, void **data)
resource malloc function to allocate ndarrays for graph passes
Definition: lib_api.h:374
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, void *cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, mxnet::ext::sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
returns status of calling Forward/Backward function for operator from library
Definition: lib_api.h:648
std::stringstream & add(const char *file, int line)
virtual void Reset()
Definition: lib_api.h:678
Definition: lib_api.h:265
std::string op
Definition: lib_api.h:569
Definition: lib_api.h:251
Definition: lib_api.h:252
Definition: lib_api.h:495
Definition: lib_api.h:495
MXReturnValue(* mutateInputs_t)(const std::unordered_map< std::string, std::string > &attributes, std::vector< int > *input_indices)
Definition: lib_api.h:737
int(* msgSize_t)(void)
Definition: lib_api.h:1126
mx_gpu_rand_t * get_gpu_rand_states() const
get pointer to initialized and seeded random number states located on GPU
Definition: lib_api.h:445
std::map< std::string, reviewSubgraph_t > review_map
Definition: lib_api.h:861
MXStorageType stype
Definition: lib_api.h:366
int(* opCallFComp_t)(fcomp_t fcomp, const char *const *keys, const char *const *vals, int num, const int64_t **inshapes, int *indims, void **indata, int *intypes, size_t *inIDs, const char **indev_type, int *indev_id, int num_in, const int64_t **outshapes, int *outdims, void **outdata, int *outtypes, size_t *outIDs, const char **outdev_type, int *outdev_id, int num_out, xpu_malloc_t cpu_malloc, void *cpu_alloc, xpu_malloc_t gpu_malloc, void *gpu_alloc, void *cuda_stream, sparse_malloc_t sparse_malloc, void *sparse_alloc, int *instypes, int *outstypes, void **in_indices, void **out_indices, void **in_indptr, void **out_indptr, int64_t *in_indices_shapes, int64_t *out_indices_shapes, int64_t *in_indptr_shapes, int64_t *out_indptr_shapes, void *rng_cpu_states, void *rng_gpu_states)
Definition: lib_api.h:988
MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char **strategy, mxnet::ext::supportedOps_t *supportedOps, mxnet::ext::createSelector_t *createSelector, mxnet::ext::reviewSubgraph_t *reviewSubgraph, const char **op_name)
returns partitioner registration at specified index
MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char *const *keys, const char *const *vals, int num, int *num_in, int *num_out)
returns status of calling parse attributes function for operator from library
void * data_ptr
Definition: lib_api.h:347