25 #ifndef MXNET_RUNTIME_PACKED_FUNC_H_ 26 #define MXNET_RUNTIME_PACKED_FUNC_H_ 28 #include <dmlc/logging.h> 46 #include <type_traits> 67 class MXNetArgsSetter;
97 using FType = std::function<void (MXNetArgs args, MXNetRetValue* rv)>;
121 template<
typename... Args>
133 return body_ ==
nullptr;
137 return body_ !=
nullptr;
148 template<
typename FType>
183 template<
typename R,
typename ...Args>
235 template<
typename FLambda,
236 typename =
typename std::enable_if<
237 std::is_convertible<FLambda,
238 std::function<R(Args...)>
241 this->AssignTypedLambda(typed_lambda);
259 template<
typename FLambda,
260 typename =
typename std::enable_if<
261 std::is_convertible<FLambda,
262 std::function<R(Args...)>
265 this->AssignTypedLambda(typed_lambda);
298 return packed_ ==
nullptr;
302 return packed_ !=
nullptr;
316 template<
typename FLambda>
317 inline void AssignTypedLambda(FLambda flambda);
333 const int* type_codes,
336 type_codes(type_codes),
337 num_args(num_args) { }
339 inline int size()
const;
363 #define MXNET_CHECK_TYPE_CODE(CODE, T) \ 364 CHECK_EQ(CODE, T) << " expected " \ 365 << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ 380 static const int code = 0;
389 operator double()
const {
393 if (type_code_ ==
kDLInt) {
394 return static_cast<double>(value_.v_int64);
397 return value_.v_float64;
399 operator int64_t()
const {
401 return value_.v_int64;
403 operator uint64_t()
const {
405 return value_.v_int64;
407 operator int()
const {
409 CHECK_LE(value_.v_int64,
410 std::numeric_limits<int>::max());
411 return static_cast<int>(value_.v_int64);
413 operator bool()
const {
415 return value_.v_int64 != 0;
417 operator void*()
const {
418 if (type_code_ ==
kNull)
return nullptr;
421 return value_.v_handle;
424 if (type_code_ ==
kNull) {
431 template<
typename TObjectRef,
432 typename =
typename std::enable_if<
433 std::is_class<TObjectRef>::value>::type>
434 inline bool IsObjectRef()
const;
446 return static_cast<T*
>(value_.v_handle);
454 : value_(value), type_code_(type_code) {}
481 using MXNetPODValue_::operator double;
482 using MXNetPODValue_::operator int64_t;
483 using MXNetPODValue_::operator uint64_t;
484 using MXNetPODValue_::operator int;
485 using MXNetPODValue_::operator bool;
486 using MXNetPODValue_::operator
void*;
487 using MXNetPODValue_::operator
ObjectRef;
491 operator std::string()
const {
492 if (type_code_ ==
kBytes) {
494 return std::string(arr->
data, arr->
size);
497 return std::string(value_.v_str);
501 if (type_code_ ==
kStr) {
505 if (type_code_ ==
kNull) {
511 return value_.v_type;
516 operator ::mxnet::NDArray*()
const {
517 if (type_code_ ==
kNull) {
526 return *ptr<PackedFunc>();
528 template<
typename FType>
536 template<
typename TObjectRef>
537 inline TObjectRef AsObjectRef()
const;
539 typename =
typename std::enable_if<
540 std::is_class<T>::value>::type>
541 inline operator T()
const;
563 other.type_code_ =
kNull;
570 using MXNetPODValue_::operator double;
571 using MXNetPODValue_::operator int64_t;
572 using MXNetPODValue_::operator uint64_t;
573 using MXNetPODValue_::operator int;
574 using MXNetPODValue_::operator bool;
575 using MXNetPODValue_::operator
void*;
576 using MXNetPODValue_::operator
ObjectRef;
583 operator std::string()
const {
584 if (type_code_ ==
kBytes) {
585 return *ptr<std::string>();
588 return *ptr<std::string>();
591 if (type_code_ ==
kStr) {
595 return value_.v_type;
603 return *ptr<PackedFunc>();
605 template<
typename FType>
612 value_ = other.value_;
613 type_code_ = other.type_code_;
614 other.type_code_ =
kNull;
619 value_.v_float64 = value;
623 this->SwitchToPOD(
kNull);
624 value_.v_handle = value;
629 value_.v_handle = value;
633 this->SwitchToPOD(
kDLInt);
634 value_.v_int64 = value;
638 this->SwitchToPOD(
kDLInt);
639 value_.v_int64 = value;
643 this->SwitchToPOD(
kDLInt);
644 value_.v_int64 = value;
648 this->SwitchToClass(
kStr, value);
657 return operator=(other.operator
DLDataType());
664 return operator=(std::move(other.
data_));
675 template<
typename FType>
677 return operator=(f.packed());
689 value_.v_handle =
reinterpret_cast<void*
>(value);
693 typename =
typename std::enable_if<
696 this->SwitchToClass<T>(
710 int* ret_type_code) {
712 CHECK(type_code_ !=
kStr && type_code_ !=
kBytes);
714 *ret_type_code = type_code_;
721 type_code_ !=
kStr) <<
"MXNetRetValue.value can only be used for POD data";
726 typename =
typename std::enable_if<
727 std::is_class<T>::value>::type>
728 inline operator T()
const;
729 template<
typename TObjectRef>
730 inline TObjectRef AsObjectRef()
const;
734 void Assign(
const T& other) {
735 switch (other.type_code()) {
737 SwitchToClass<std::string>(
kStr, other);
741 SwitchToClass<std::string>(
kBytes, other);
749 *
this = other.operator ObjectRef();
754 SwitchToPOD(other.type_code());
755 value_ = other.value_;
757 LOG(FATAL) <<
"Does not support ext type";
764 void SwitchToPOD(
int type_code) {
765 if (type_code_ != type_code) {
767 type_code_ = type_code;
771 void SwitchToClass(
int type_code, T v) {
772 if (type_code_ != type_code) {
774 type_code_ = type_code;
775 value_.v_handle =
new T(v);
777 *
static_cast<T*
>(value_.v_handle) = v;
781 if (other.data_ !=
nullptr) {
783 type_code_ = type_code;
785 value_.v_handle = other.data_;
786 other.data_ =
nullptr;
792 if (type_code_ ==
kNull)
return;
793 switch (type_code_) {
794 case kStr:
delete ptr<std::string>();
break;
797 static_cast<Object*
>(value_.v_handle)->DecRef();
802 LOG(FATAL) <<
"Does not support ext type";
811 if (s.length() == 0) {
816 const char* scan =
nullptr;
817 if (s.substr(0, 3) ==
"int") {
819 }
else if (s.substr(0, 4) ==
"uint") {
821 }
else if (s.substr(0, 5) ==
"float") {
823 }
else if (s.substr(0, 6) ==
"handle") {
826 scan = s.c_str() + 6;
827 }
else if (s ==
"bool") {
832 }
else if (s.substr(0, 6) ==
"custom") {
833 LOG(FATAL) <<
"custom MXNetDataType is not supported";
837 LOG(FATAL) <<
"unknown type " << s;
840 uint8_t bits =
static_cast<uint8_t
>(strtoul(scan, &xdelim, 10));
841 if (bits != 0) t.
bits = bits;
842 char* endpt = xdelim;
843 if (*xdelim ==
'x') {
844 t.
lanes =
static_cast<uint16_t
>(strtoul(xdelim + 1, &endpt, 10));
846 CHECK(endpt == s.c_str() + s.length()) <<
"unknown type " << s;
853 case kDLInt:
return "int";
856 case kStr:
return "str";
857 case kBytes:
return "bytes";
859 case kNull:
return "NULL";
862 default: LOG(FATAL) <<
"unknown type_code=" 863 <<
static_cast<int>(type_code);
return "";
868 if (s ==
"float32") {
870 }
else if (s ==
"float64") {
872 }
else if (s ==
"float16") {
874 }
else if (s ==
"uint8") {
876 }
else if (s ==
"int8") {
878 }
else if (s ==
"int32") {
880 }
else if (s ==
"int64") {
882 }
else if (s ==
"bool") {
885 LOG(FATAL) <<
"unknown type " << s;
887 LOG(FATAL) <<
"should not reach here ";
892 if (s ==
"float32") {
894 }
else if (s ==
"float64") {
896 }
else if (s ==
"float16") {
898 }
else if (s ==
"uint8") {
900 }
else if (s ==
"int8") {
902 }
else if (s ==
"int32") {
904 }
else if (s ==
"int64") {
907 LOG(FATAL) <<
"unknown type " << s;
909 LOG(FATAL) <<
"should not reach here ";
915 os <<
"bool";
return os;
920 LOG(FATAL) <<
"custom MXNetDataType is not supported";
924 os << static_cast<int>(t.
bits);
926 os << 'x' << static_cast<int>(t.
lanes);
936 CHECK_LT(i, num_args)
937 <<
"not enough argument passed, " 938 << num_args <<
" passed" 939 <<
" but request arg[" << i <<
"].";
958 template<
bool stop, std::
size_t I,
typename F>
960 template<
typename T,
typename ...Args>
961 static void run(
const F& f, T&& value, Args&&... args) {
962 f(I, std::forward<T>(value));
964 ::run(f, std::forward<Args>(args)...);
968 template<std::
size_t I,
typename F>
973 template<
typename F,
typename ...Args>
976 ::run(f, std::forward<Args>(args)...);
984 : values_(values), type_codes_(type_codes) {}
987 typename =
typename std::enable_if<
988 std::is_integral<T>::value>::type>
990 values_[i].v_int64 =
static_cast<int64_t
>(value);
994 values_[i].v_int64 =
static_cast<int64_t
>(value);
996 static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
1000 values_[i].v_float64 = value;
1004 values_[i].v_handle = value;
1005 type_codes_[i] =
kNull;
1008 values_[i] = value.
value_;
1012 values_[i].v_handle = value;
1016 values_[i].v_handle = value;
1020 values_[i].v_str = value;
1021 type_codes_[i] =
kStr;
1027 values_[i].v_str = value.c_str();
1028 type_codes_[i] =
kStr;
1031 values_[i].v_type = value;
1042 values_[i].v_handle =
const_cast<PackedFunc*
>(&value);
1045 template<
typename FType>
1051 values_[i].v_handle = value.
data_.data_;
1054 type_codes_[i] =
kNull;
1059 values_[i].v_str = value.
ptr<std::string>()->c_str();
1060 type_codes_[i] =
kStr;
1063 values_[i] = value.
value_;
1075 template<
typename... Args>
1077 const int kNumArgs =
sizeof...(Args);
1078 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
1080 int type_codes[kArraySize];
1082 std::forward<Args>(args)...);
1084 body_(
MXNetArgs(values, type_codes, kNumArgs), &rv);
1089 template<
typename R,
int nleft,
int index,
typename F>
1091 template<
typename ...Args>
1095 Args&&... unpacked_args) {
1097 ::run(f, args_pack, rv,
1098 std::forward<Args>(unpacked_args)...,
1103 template<
typename R,
int index,
typename F>
1105 template<
typename ...Args>
1109 Args&&... unpacked_args) {
1110 *rv = R(f(std::forward<Args>(unpacked_args)...));
1114 template<
int index,
typename F>
1116 template<
typename ...Args>
1120 Args&&... unpacked_args) {
1121 f(std::forward<Args>(unpacked_args)...);
1125 template<
typename R,
int nargs,
typename F>
1130 template<
typename R,
typename ...Args>
1132 return R(pf(std::forward<Args>(args)...));
1135 template<
typename R>
1137 template<
typename ...Args>
1139 return pf(std::forward<Args>(args)...);
1145 template<
typename ...Args>
1147 pf(std::forward<Args>(args)...);
1152 template<
typename R,
typename ...Args>
1154 : packed_(packed) {}
1156 template<
typename R,
typename ...Args>
1160 template<
typename R,
typename ...Args>
1164 template<
typename R,
typename ...Args>
1165 template<
typename FType>
1172 template<
typename R,
typename ...Args>
1175 ::run(packed_, std::forward<Args>(args)...);
1180 template<
typename T,
typename TSrc,
bool is_ext,
bool is_nd>
1183 static_assert(!is_ext && !is_nd,
"The default case accepts only non-extensions");
1184 return self->template AsObjectRef<T>();
1190 template<
typename T,
typename>
1191 inline MXNetRetValue::operator T()
const {
1201 #endif // MXNET_RUNTIME_PACKED_FUNC_H_ MXNetArgValue()
default constructor
Definition: packed_func.h:471
Definition: c_runtime_api.h:46
Definition: packed_func.h:1181
MXNetRetValue & operator=(int value)
Definition: packed_func.h:637
void operator()(size_t i, const MXNetArgValue &value) const
Definition: packed_func.h:1007
void operator()(size_t i, DLTensor *value) const
Definition: packed_func.h:1015
void * v_handle
Definition: c_runtime_api.h:79
Definition: c_runtime_api.h:62
MXNetPODValue_()
Definition: packed_func.h:452
MXNetRetValue & operator=(std::string value)
Definition: packed_func.h:647
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&...unpacked_args)
Definition: packed_func.h:1117
Definition: packed_func.h:1090
std::function< void(MXNetArgs args, MXNetRetValue *rv)> FType
The internal std::function.
Definition: packed_func.h:97
Definition: c_runtime_api.h:69
Definition: c_runtime_api.h:47
The type trait indicates subclass of TVM's NDArray. For irrelavant classes, code = -1...
Definition: ndarray.h:38
DLDataType String2DLDataType(std::string s)
convert a string to TVM type.
Definition: packed_func.h:808
void operator()(size_t i, const ObjectRef &value) const
Definition: packed_func.h:1049
namespace of mxnet
Definition: api_registry.h:33
MXNetRetValue & operator=(const MXNetDataType &other)
Definition: packed_func.h:656
int size() const
Definition: packed_func.h:943
MXNetRetValue & operator=(PackedFunc f)
Definition: packed_func.h:671
void CallPacked(MXNetArgs args, MXNetRetValue *rv) const
Call the function in packed format.
Definition: packed_func.h:947
void operator()(size_t i, const PackedFunc &value) const
Definition: packed_func.h:1041
const char * TypeCode2Str(int type_code)
Convert type code to its name.
Definition: packed_func.h:851
Arguments into TVM functions.
Definition: packed_func.h:321
MXNetRetValue & operator=(MXNetRetValue &&other)
Definition: packed_func.h:610
void operator()(size_t i, const std::string &value) const
Definition: packed_func.h:1026
void operator()(size_t i, T value) const
Definition: packed_func.h:989
int String2MXNetType(const std::string &s)
Definition: packed_func.h:891
A custom smart pointer for Object.
Definition: object.h:345
MXNetArgsSetter(MXNetValue *values, int *type_codes)
Definition: packed_func.h:983
void operator()(size_t i, const TypedPackedFunc< FType > &value) const
Definition: packed_func.h:1046
TypedPackedFunc()
default constructor
Definition: packed_func.h:189
Definition: c_runtime_api.h:53
void operator()(size_t i, const MXNetByteArray &value) const
Definition: packed_func.h:1037
void operator()(size_t i, std::nullptr_t value) const
Definition: packed_func.h:1003
MXNetRetValue(MXNetRetValue &&other)
move constructor from anoter return value.
Definition: packed_func.h:560
int type_code_
the type code
Definition: packed_func.h:459
MXNetRetValue & operator=(const T &other)
Definition: packed_func.h:695
MXNetRetValue & operator=(MXNetByteArray value)
Definition: packed_func.h:659
Base class of all object reference.
Definition: object.h:499
void operator()(size_t i, const MXNetRetValue &value) const
Definition: packed_func.h:1057
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:132
Packed function is a type-erased function. The arguments are passed by packed format.
Definition: packed_func.h:77
uint8_t code
Type code of base types. We keep it uint8_t instead of DLDataTypeCode for minimal memory footprint...
Definition: dlpack.h:100
Definition: c_runtime_api.h:50
void operator()(size_t i, MXNetDataType dtype) const
Definition: packed_func.h:1034
void operator()(size_t i, double value) const
Definition: packed_func.h:999
MXNetRetValue & operator=(bool value)
Definition: packed_func.h:642
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:301
Definition: packed_func.h:1136
MXNetRetValue & operator=(std::nullptr_t value)
Definition: packed_func.h:622
MXNetRetValue()
default constructor
Definition: packed_func.h:555
void operator()(size_t i, const char *value) const
Definition: packed_func.h:1019
MXNetRetValue & operator=(const TypedPackedFunc< FType > &f)
Definition: packed_func.h:676
Definition: packed_func.h:981
A single argument value to PackedFunc. Containing both type_code and MXNetValue.
Definition: packed_func.h:468
MXNetPODValue_(MXNetValue value, int type_code)
Definition: packed_func.h:453
MXNetArgValue operator[](int i) const
Get i-th argument.
Definition: packed_func.h:935
R call_packed(const PackedFunc &pf, Args &&...args)
Definition: packed_func.h:1131
const MXNetValue & value() const
Definition: packed_func.h:718
const MXNetValue * values
Definition: packed_func.h:323
TSelf & operator=(FLambda typed_lambda)
copy assignment operator from typed lambda
Definition: packed_func.h:264
Please refer to TypedPackedFunc<R(Args..)>.
Definition: packed_func.h:149
uint8_t bits
Number of bits, common choices are 8, 16, 32.
Definition: dlpack.h:104
void operator()(size_t i, uint64_t value) const
Definition: packed_func.h:993
MXNetRetValue & operator=(int64_t value)
Definition: packed_func.h:632
PackedFunc(FType body)
constructing a packed function from a std::function.
Definition: packed_func.h:106
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:72
Type traits to mark if a class is tvm extension type.
Definition: packed_func.h:379
~MXNetRetValue()
destructor
Definition: packed_func.h:566
base class of all object containers.
Definition: object.h:149
MXNetRetValue & operator=(double value)
Definition: packed_func.h:617
Runtime primitive data type.
Definition: data_type.h:41
void operator()(size_t i, void *value) const
Definition: packed_func.h:1011
void unpack_call(const F &f, const MXNetArgs &args, MXNetRetValue *rv)
Definition: packed_func.h:1126
ObjectPtr< Object > data_
Internal pointer that backs the reference.
Definition: object.h:572
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&...unpacked_args)
Definition: packed_func.h:1106
MXNetRetValue & operator=(DLDataType t)
Definition: packed_func.h:651
static void run(const F &f, T &&value, Args &&...args)
Definition: packed_func.h:961
PackedFunc()
default constructor
Definition: packed_func.h:99
bool operator!=(std::nullptr_t null) const
Definition: packed_func.h:136
MXNetValue value_
The value.
Definition: packed_func.h:457
#define MXNET_CHECK_TYPE_CODE(CODE, T)
convert a string to TVM type.
Definition: packed_func.h:363
T * ptr() const
return handle as specific pointer type.
Definition: packed_func.h:445
static void run(const F &f, const MXNetArgs &args_pack, MXNetRetValue *rv, Args &&...unpacked_args)
Definition: packed_func.h:1092
A device-independent managed NDArray abstraction.
A managed object in MXNet runtime.
bool defined() const
Definition: object.h:538
Union type of values being passed through API and function calls.
Definition: c_runtime_api.h:76
void MoveToCHost(MXNetValue *ret_value, int *ret_type_code)
Move the value back to front-end via C API. This marks the current container as null. The managed resources is moved to front-end and the front end should take charge in managing them.
Definition: packed_func.h:709
FType body() const
Definition: packed_func.h:951
const PackedFunc & packed() const
Definition: packed_func.h:293
Definition: c_runtime_api.h:55
void operator()(size_t i, DLDataType value) const
Definition: packed_func.h:1030
MXNetRetValue & operator=(ObjectRef other)
Definition: packed_func.h:663
static R run(const PackedFunc &pf, Args &&...args)
Definition: packed_func.h:1138
MXNetRetValue & operator=(void *value)
Definition: packed_func.h:627
std::ostream & operator<<(std::ostream &os, DLDataType t)
Definition: packed_func.h:913
Definition: c_runtime_api.h:51
Byte array type used to pass in byte array When kBytes is used as data type.
Definition: c_runtime_api.h:88
MXNetRetValue & operator=(ObjectPtr< T > other)
Definition: packed_func.h:667
const int * type_codes
Definition: packed_func.h:324
int type_code() const
Definition: packed_func.h:435
static T Apply(const TSrc *self)
Definition: packed_func.h:1182
Base expr nodes in MXNet.
MXNetRetValue & operator=(const MXNetRetValue &other)
Definition: packed_func.h:679
Definition: c_runtime_api.h:57
Definition: c_runtime_api.h:54
MXNetRetValue & operator=(const MXNetArgValue &other)
Definition: packed_func.h:683
bool operator==(std::nullptr_t null) const
Definition: packed_func.h:297
Definition: packed_func.h:959
MXNetArgs(const MXNetValue *values, const int *type_codes, int num_args)
constructor
Definition: packed_func.h:332
const char * data
Definition: c_runtime_api.h:89
TSelf & operator=(PackedFunc packed)
copy assignment operator from PackedFunc.
Definition: packed_func.h:273
TypedPackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:191
void for_each(const F &f, Args &&...args)
Definition: packed_func.h:974
Internal base class to handle conversion to POD values.
Definition: packed_func.h:387
MXNetRetValue(const MXNetRetValue &other)
Definition: packed_func.h:579
MXNetRetValue & operator=(::mxnet::NDArray *value)
Definition: packed_func.h:687
runtime::MXNetDataType MXNetDataType
Definition: data_type.h:214
The data type the tensor can hold.
Definition: dlpack.h:94
uint16_t lanes
Number of lanes in the type, used for vector types.
Definition: dlpack.h:106
Plain C Tensor object, does not manage memory.
Definition: dlpack.h:112
static void run(const PackedFunc &pf, Args &&...args)
Definition: packed_func.h:1146
size_t size
Definition: c_runtime_api.h:90
ndarray interface
Definition: ndarray.h:82
PackedFunc(std::nullptr_t null)
constructor from null
Definition: packed_func.h:101
MXNetArgValue(MXNetValue value, int type_code)
constructor
Definition: packed_func.h:477
MXNetRetValue operator()(Args &&...args) const
Call packed function by directly passing in unpacked format.
Definition: packed_func.h:1076
Return Value container, Unlike MXNetArgValue, which only holds reference and do not delete the underl...
Definition: packed_func.h:552
const MXNetValue & value() const
Definition: packed_func.h:532
static void run(const F &f)
Definition: packed_func.h:970
int String2MXNetTypeWithBool(const std::string &s)
Definition: packed_func.h:867
Definition: c_runtime_api.h:48
int num_args
Definition: packed_func.h:325
TypedPackedFunc(const FLambda &typed_lambda)
construct from a lambda function with the same signature.
Definition: packed_func.h:240
A PackedFunc wrapper to provide typed function signature. It is backed by a PackedFunc internally...
Definition: packed_func.h:184