mxnet
lazy_alloc_array.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
26 #define MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
27 
28 #include <dmlc/logging.h>
29 #include <memory>
30 #include <mutex>
31 #include <array>
32 #include <vector>
33 #include <atomic>
34 
35 namespace mxnet {
36 namespace common {
37 
/*!
 * \brief Array whose elements are created lazily, the first time an index is
 *  requested through Get().
 *
 * Indices [0, kInitSize) are stored in the fixed-size head_; larger indices
 * are stored in the growable more_ at position (index - kInitSize).
 * Elements are owned through std::shared_ptr<TElem>.
 * \tparam TElem element type.
 */
template<typename TElem>
class LazyAllocArray {
 public:
  LazyAllocArray();
  // Holds a std::mutex and std::atomic, so copying is not meaningful.
  LazyAllocArray(const LazyAllocArray&) = delete;
  LazyAllocArray& operator=(const LazyAllocArray&) = delete;
  /*!
   * \brief Get the element at the given index, creating it with creator
   *  if it does not exist yet.
   * \param index the array index position; must be non-negative.
   * \param creator nullary callable whose result is passed to the
   *  std::shared_ptr<TElem> constructor (presumably a newly allocated TElem*).
   * \return the element, or nullptr if it does not exist and Clear() is running.
   */
  template<typename FCreate>
  inline std::shared_ptr<TElem> Get(int index, FCreate creator);
  /*!
   * \brief For each non-null element of the array, call fvisit.
   * \param fvisit callable invoked as fvisit(size_t index, TElem* elem).
   */
  template<typename FVisit>
  inline void ForEach(FVisit fvisit);
  /*! \brief Clear all the allocated elements in the array. */
  inline void Clear();

 private:
  /*!
   * \brief RAII helper that releases a std::unique_lock on construction and
   *  re-acquires it on destruction (no-op when given nullptr).
   */
  template<typename SyncObject>
  class unique_unlock {
   public:
    explicit unique_unlock(std::unique_lock<SyncObject> *lock)
        : lock_(lock) {
      if (lock_) {
        lock_->unlock();
      }
    }
    ~unique_unlock() {
      if (lock_) {
        lock_->lock();
      }
    }
    // A copy would unlock/relock the same lock twice; forbid it.
    unique_unlock(const unique_unlock&) = delete;
    unique_unlock& operator=(const unique_unlock&) = delete;

   private:
    std::unique_lock<SyncObject> *lock_;
  };

  /*! \brief number of slots stored inline in head_ */
  static constexpr std::size_t kInitSize = 16;
  /*! \brief guards element creation, ForEach and Clear */
  std::mutex create_mutex_;
  /*! \brief storage for indices [0, kInitSize) */
  std::array<std::shared_ptr<TElem>, kInitSize> head_;
  /*! \brief storage for indices >= kInitSize, at position index - kInitSize */
  std::vector<std::shared_ptr<TElem> > more_;
  /*! \brief true while Clear() runs; Get() then refuses to create elements */
  std::atomic<bool> is_clearing_;
};
89 
90 template<typename TElem>
92  : is_clearing_(false) {
93 }
94 
95 // implementations
96 template<typename TElem>
97 template<typename FCreate>
98 inline std::shared_ptr<TElem> LazyAllocArray<TElem>::Get(int index, FCreate creator) {
99  CHECK_GE(index, 0);
100  size_t idx = static_cast<size_t>(index);
101  if (idx < kInitSize) {
102  std::shared_ptr<TElem> ptr = head_[idx];
103  if (ptr) {
104  return ptr;
105  } else {
106  std::lock_guard<std::mutex> lock(create_mutex_);
107  if (!is_clearing_.load()) {
108  std::shared_ptr<TElem> ptr = head_[idx];
109  if (ptr) {
110  return ptr;
111  }
112  ptr = head_[idx] = std::shared_ptr<TElem>(creator());
113  return ptr;
114  }
115  }
116  } else {
117  std::lock_guard<std::mutex> lock(create_mutex_);
118  if (!is_clearing_.load()) {
119  idx -= kInitSize;
120  if (more_.size() <= idx) {
121  more_.reserve(idx + 1);
122  while (more_.size() <= idx) {
123  more_.push_back(std::shared_ptr<TElem>(nullptr));
124  }
125  }
126  std::shared_ptr<TElem> ptr = more_[idx];
127  if (ptr) {
128  return ptr;
129  }
130  ptr = more_[idx] = std::shared_ptr<TElem>(creator());
131  return ptr;
132  }
133  }
134  return nullptr;
135 }
136 
137 template<typename TElem>
139  std::unique_lock<std::mutex> lock(create_mutex_);
140  is_clearing_.store(true);
141  // Currently, head_ and more_ never get smaller, so it's safe to
142  // iterate them outside of the lock. The loops should catch
143  // any growth which might happen when create_mutex_ is unlocked
144  for (size_t i = 0; i < head_.size(); ++i) {
145  std::shared_ptr<TElem> p = head_[i];
146  head_[i] = std::shared_ptr<TElem>(nullptr);
147  unique_unlock<std::mutex> unlocker(&lock);
148  p = std::shared_ptr<TElem>(nullptr);
149  }
150  for (size_t i = 0; i < more_.size(); ++i) {
151  std::shared_ptr<TElem> p = more_[i];
152  more_[i] = std::shared_ptr<TElem>(nullptr);
153  unique_unlock<std::mutex> unlocker(&lock);
154  p = std::shared_ptr<TElem>(nullptr);
155  }
156  more_.clear();
157  is_clearing_.store(false);
158 }
159 
160 template<typename TElem>
161 template<typename FVisit>
162 inline void LazyAllocArray<TElem>::ForEach(FVisit fvisit) {
163  std::lock_guard<std::mutex> lock(create_mutex_);
164  for (size_t i = 0; i < head_.size(); ++i) {
165  if (head_[i].get() != nullptr) {
166  fvisit(i, head_[i].get());
167  }
168  }
169  for (size_t i = 0; i < more_.size(); ++i) {
170  if (more_[i].get() != nullptr) {
171  fvisit(i + kInitSize, more_[i].get());
172  }
173  }
174 }
175 
176 } // namespace common
177 } // namespace mxnet
178 #endif // MXNET_COMMON_LAZY_ALLOC_ARRAY_H_
std::shared_ptr< TElem > Get(int index, FCreate creator)
Get the element at the given index, creating it with creator if it does not exist yet.
Definition: lazy_alloc_array.h:98
namespace of mxnet
Definition: api_registry.h:33
void Clear()
Clear all the allocated elements in the array.
Definition: lazy_alloc_array.h:138
void ForEach(FVisit fvisit)
Call fvisit for each non-null element of the array.
Definition: lazy_alloc_array.h:162
Definition: lazy_alloc_array.h:39
LazyAllocArray()
Definition: lazy_alloc_array.h:91