mxnet
exec_utils.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_COMMON_EXEC_UTILS_H_
25 #define MXNET_COMMON_EXEC_UTILS_H_
26 
27 #include <vector>
28 #include "../common/utils.h"
29 
30 namespace mxnet {
31 namespace common {
32 
33 /*
34  * \brief setup default-storage tblobs from source NDArrays. If any source NDArray has non-default
35  * storage, it creates a temp NDArray with default storage and uses the temp tblob. The
36  * function also records the indices of non-default source NDArrays and the indices of
37  * their corresponding temporary NDArrays in the temp array.
38  * \param src list of source NDArray
39  * \param blobs list of tblobs to return
40  * \param temp_src list of source NDArrays which requires temporary default storage representation
41  * \param temp_dst list of temporary destination NDArrays for default storage representation
42  * \param idx_map mapping from indices in source NDArrays to indices in temp_dst. When not set,
43  indices are not recorded
44  * \return true if any source NDArray need to cast storage
45  */
46 inline bool SetupDefaultBlobs(const std::vector<NDArray>& src,
47  std::vector<TBlob> *blobs,
48  std::vector<NDArray> *temp_src,
49  std::vector<NDArray> *temp_dst,
50  std::unordered_map<uint32_t, uint32_t> *idx_map = nullptr) {
51  bool require_cast = false;
52  for (size_t i = 0; i < src.size(); i++) {
53  auto& nd = src[i];
54  if (nd.storage_type() != kDefaultStorage) {
55  if (idx_map != nullptr) {
56  (*idx_map)[i] = temp_dst->size();
57  }
58  NDArray temp(nd.shape(), nd.ctx(), false, nd.dtype());
59  temp_src->emplace_back(nd);
60  temp_dst->emplace_back(temp);
61  blobs->emplace_back(temp.data());
62  require_cast = true;
63  } else {
64  blobs->push_back(nd.data());
65  }
66  }
67  return require_cast;
68 }
69 
70 /*
71  * \brief setup default-storage tblobs for input and output NDArrays.
72  * If any NDArray has non-default storage,
73  * it creates a temp NDArray with default storage and uses the temp tblob. The
74  * function also records the indices of non-default source NDArrays and the indices of
75  * their corresponding temporary NDArrays in the temp array.
76  */
77 inline void SetupDefaultBlobsInOut(const std::vector<NDArray> &ndinputs,
78  const std::vector<NDArray> &ndoutputs,
79  std::vector<TBlob> *input_blobs,
80  std::vector<TBlob> *output_blobs,
81  std::vector<NDArray> *pre_temp_src,
82  std::vector<NDArray> *pre_temp_dst,
83  std::vector<NDArray> *post_temp_src,
84  std::vector<NDArray> *post_temp_dst,
85  std::unordered_map<uint32_t, uint32_t> *in_temp_idx_map,
86  const std::vector<uint32_t> &mutate_idx) {
87  // populate input blobs
88  SetupDefaultBlobs(ndinputs, input_blobs, pre_temp_src, pre_temp_dst, in_temp_idx_map);
89  // populate output blobs
90  SetupDefaultBlobs(ndoutputs, output_blobs, post_temp_dst, post_temp_src);
91  // add mutable inputs to post temp list
92  for (const auto idx : mutate_idx) {
93  auto map_iter = in_temp_idx_map->find(idx);
94  if (map_iter != in_temp_idx_map->end()) {
95  post_temp_src->push_back(pre_temp_dst->at(map_iter->second));
96  post_temp_dst->push_back(ndinputs[idx]);
97  }
98  }
99 }
100 
101 /*
102  * \brief cast the NDArrays in `src` and store the result in NDArrays in `dst`.
103  * This is only used for storage fallback in executor.
104  * \param src list of source NDArray to cast
105  * \param dst list of destionation NDArray which hold the result of cast_storage operation
106  * \param ctx operator context for cast_storage operation
107  */
108 inline void CastNonDefaultStorage(const std::vector<NDArray>& src,
109  const std::vector<NDArray>& dst,
110  const OpContext& ctx,
111  const bool is_gpu) {
112  CHECK_EQ(dst.size(), src.size());
113  for (size_t i = 0; i < src.size(); i++) {
114  if (is_gpu) {
115 #if MXNET_USE_CUDA
116  CastStorageDispatch<gpu>(ctx, src[i], dst[i]);
117 #else
118  LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
119 #endif
120  } else {
121  CastStorageDispatch<cpu>(ctx, src[i], dst[i]);
122  }
123  }
124 }
125 } // namespace common
126 } // namespace mxnet
127 #endif // MXNET_COMMON_EXEC_UTILS_H_
Definition: ndarray.h:61
#define MXNET_GPU_NOT_ENABLED_ERROR
Error message for using gpu when MXNET_USE_CUDA==0.
Definition: base.h:68
namespace of mxnet
Definition: base.h:127
All the possible information needed by Operator.Forward and Backward This is the superset of RunConte...
Definition: op_attr_types.h:66
bool SetupDefaultBlobs(const std::vector< NDArray > &src, std::vector< TBlob > *blobs, std::vector< NDArray > *temp_src, std::vector< NDArray > *temp_dst, std::unordered_map< uint32_t, uint32_t > *idx_map=nullptr)
Definition: exec_utils.h:46
void SetupDefaultBlobsInOut(const std::vector< NDArray > &ndinputs, const std::vector< NDArray > &ndoutputs, std::vector< TBlob > *input_blobs, std::vector< TBlob > *output_blobs, std::vector< NDArray > *pre_temp_src, std::vector< NDArray > *pre_temp_dst, std::vector< NDArray > *post_temp_src, std::vector< NDArray > *post_temp_dst, std::unordered_map< uint32_t, uint32_t > *in_temp_idx_map, const std::vector< uint32_t > &mutate_idx)
Definition: exec_utils.h:77
void CastNonDefaultStorage(const std::vector< NDArray > &src, const std::vector< NDArray > &dst, const OpContext &ctx, const bool is_gpu)
Definition: exec_utils.h:108
ndarray interface
Definition: ndarray.h:79