mxnet
contrib.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_CPP_CONTRIB_H_
26 #define MXNET_CPP_CONTRIB_H_
27 
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "mxnet-cpp/symbol.h"
33 
34 namespace mxnet {
35 namespace cpp {
36 namespace details {
37 
/*!
 * \brief Split a string on every occurrence of a delimiter.
 * \param str the string to split
 * \param delimiter the separator; may be longer than one character
 * \return the pieces between delimiters, in order. Adjacent, leading or
 *         trailing delimiters produce empty pieces; a string without the
 *         delimiter (including the empty string) yields exactly one piece.
 *         An empty delimiter yields the whole input as a single piece.
 */
inline std::vector<std::string> split(const std::string& str, const std::string& delimiter) {
  std::vector<std::string> splitted;
  if (delimiter.empty()) {
    // find("") matches at every position and would never make progress;
    // treat an empty delimiter as "nothing to split on".
    splitted.push_back(str);
    return splitted;
  }
  size_t last = 0;
  size_t next = 0;
  while ((next = str.find(delimiter, last)) != std::string::npos) {
    splitted.push_back(str.substr(last, next - last));
    // Skip the whole delimiter, not just its first character.
    last = next + delimiter.size();
  }
  splitted.push_back(str.substr(last));
  return splitted;
}
55 
56 } // namespace details
57 
58 namespace contrib {
59 
// Attribute names shared with the TensorRT subgraph operator. Both strings
// must stay identical to the ones used on the operator side:
//   identifier -> https://github.com/apache/mxnet/blob/1c874cfc807cee755c38f6486e8e0f4d94416cd8/src/operator/subgraph/tensorrt/tensorrt-inl.h#L190
//   prefix     -> https://github.com/apache/mxnet/blob/master/src/operator/subgraph/tensorrt/tensorrt.cc#L244
static const std::string TENSORRT_SUBGRAPH_PARAM_IDENTIFIER("subgraph_params_names");
static const std::string TENSORRT_SUBGRAPH_PARAM_PREFIX("subgraph_param_");
72  inline void InitTensorRTParams(const mxnet::cpp::Symbol& symbol,
73  std::map<std::string, mxnet::cpp::NDArray> *argParams,
74  std::map<std::string, mxnet::cpp::NDArray> *auxParams) {
75  mxnet::cpp::Symbol internals = symbol.GetInternals();
76  mx_uint numSymbol = internals.GetNumOutputs();
77  for (mx_uint i = 0; i < numSymbol; ++i) {
78  std::map<std::string, std::string> attrs = internals[i].ListAttributes();
79  if (attrs.find(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER) != attrs.end()) {
80  std::string new_params_names;
81  std::map<std::string, mxnet::cpp::NDArray> tensorrtParams;
82  std::vector<std::string> keys = details::split(
83  attrs[TENSORRT_SUBGRAPH_PARAM_IDENTIFIER], ";");
84  for (const auto& key : keys) {
85  if (argParams->find(key) != argParams->end()) {
86  new_params_names += key + ";";
87  tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*argParams)[key];
88  argParams->erase(key);
89  } else if (auxParams->find(key) != auxParams->end()) {
90  new_params_names += key + ";";
91  tensorrtParams[TENSORRT_SUBGRAPH_PARAM_PREFIX + key] = (*auxParams)[key];
92  auxParams->erase(key);
93  }
94  }
95  std::map<std::string, std::string> new_attrs = {};
96  for (const auto& kv : tensorrtParams) {
97  // passing the ndarray address into TRT node attributes to get the weight
98  uint64_t address = reinterpret_cast<uint64_t>(kv.second.GetHandle());
99  new_attrs[kv.first] = std::to_string(address);
100  }
101  if (!new_attrs.empty()) {
102  internals[i].SetAttributes(new_attrs);
103  internals[i].SetAttribute(TENSORRT_SUBGRAPH_PARAM_IDENTIFIER,
104  new_params_names.substr(0, new_params_names.length() - 1));
105  }
106  }
107  }
108 }
109 
110 } // namespace contrib
111 } // namespace cpp
112 } // namespace mxnet
113 
114 #endif // MXNET_CPP_CONTRIB_H_
definition of symbol
namespace of mxnet
Definition: api_registry.h:33
std::map< std::string, std::string > ListAttributes() const
Symbol GetInternals() const
get the symbol whose outputs are all the internals.
void SetAttributes(const std::map< std::string, std::string > &attrs)
set a series of key-value attributes on the symbol
mx_uint GetNumOutputs() const
void InitTensorRTParams(const mxnet::cpp::Symbol &symbol, std::map< std::string, mxnet::cpp::NDArray > *argParams, std::map< std::string, mxnet::cpp::NDArray > *auxParams)
Definition: contrib.h:72
void SetAttribute(const std::string &key, const std::string &value)
set a key-value attribute on the symbol
std::vector< std::string > split(const std::string &str, const std::string &delimiter)
Definition: contrib.h:44
Symbol interface.
Definition: symbol.h:71
uint32_t mx_uint
manually defined 32-bit unsigned int type
Definition: c_api.h:57