25 #ifndef MXNET_CPP_CONTRIB_H_ 26 #define MXNET_CPP_CONTRIB_H_ 44 inline std::vector<std::string>
split(
const std::string& str,
const std::string& delimiter) {
45 std::vector<std::string> splitted;
48 while ((next = str.find(delimiter, last)) != std::string::npos) {
49 splitted.push_back(str.substr(last, next - last));
52 splitted.push_back(str.substr(last));
// Attribute key under which a TensorRT subgraph node lists the names of the
// parameters it uses, as a ";"-separated string. InitTensorRTParams looks this
// key up in each internal symbol's attributes, splits its value on ";", and
// writes back a ";"-joined list of the names it actually found in
// argParams/auxParams.
62 static const std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER =
"subgraph_params_names";
// Prefix prepended to a parameter name to form a per-parameter attribute key.
// In InitTensorRTParams, the value stored under "<prefix><name>" is the
// parameter NDArray's handle address rendered with std::to_string.
// NOTE(review): namespace-scope `static` in a header gives every including
// translation unit its own copy of these strings; `inline const` would avoid
// that if/when the project builds as C++17 — confirm the language level first.
65 static const std::string TENSORRT_SUBGRAPH_PARAM_PREFIX =
"subgraph_param_";
73 std::map<std::string, mxnet::cpp::NDArray> *argParams,
74 std::map<std::string, mxnet::cpp::NDArray> *auxParams) {
77 for (
mx_uint i = 0; i < numSymbol; ++i) {
78 std::map<std::string, std::string> attrs = internals[i].
ListAttributes();
79 if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) {
80 std::string new_params_names;
81 std::map<std::string, mxnet::cpp::NDArray> tensorrtParams;
83 attrs[TENSORRT_SUBGRAPH_PARAM_IDENTIFIER],
";");
84 for (
const auto& key : keys) {
85 if (argParams->find(key) != argParams->end()) {
86 new_params_names += key +
";";
87 tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*argParams)[key];
88 argParams->erase(key);
89 }
else if (auxParams->find(key) != auxParams->end()) {
90 new_params_names += key +
";";
91 tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*auxParams)[key];
92 auxParams->erase(key);
95 std::map<std::string, std::string> new_attrs = {};
96 for (
const auto& kv : tensorrtParams) {
98 uint64_t address =
reinterpret_cast<uint64_t
>(kv.second.GetHandle());
99 new_attrs[kv.first] = std::to_string(address);
101 if (!new_attrs.empty()) {
103 internals[i].
SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER,
104 new_params_names.substr(0, new_params_names.length() - 1));
114 #endif // MXNET_CPP_CONTRIB_H_
namespace of mxnet
Definition: api_registry.h:33
std::map< std::string, std::string > ListAttributes() const
Symbol GetInternals() const
get the internals of the symbol; returns the symbol whose outputs are all the internals.
void SetAttributes(const std::map< std::string, std::string > &attrs)
set a series of key-value attributes on the symbol
mx_uint GetNumOutputs() const
void InitTensorRTParams(const mxnet::cpp::Symbol &symbol, std::map< std::string, mxnet::cpp::NDArray > *argParams, std::map< std::string, mxnet::cpp::NDArray > *auxParams)
Definition: contrib.h:72
void SetAttribute(const std::string &key, const std::string &value)
set a key-value attribute on the symbol
std::vector< std::string > split(const std::string &str, const std::string &delimiter)
Definition: contrib.h:44
Symbol interface.
Definition: symbol.h:71
uint32_t mx_uint
manually defined unsigned int type
Definition: c_api.h:57