mxnet
kvstore.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_KVSTORE_H_
25 #define MXNET_KVSTORE_H_
26 #include <dmlc/io.h>
27 #include <vector>
28 #include <utility>
29 #include <unordered_map>
30 #include <string>
31 #include <functional>
32 #include <atomic>
33 #include "./ndarray.h"
34 #if MXNET_USE_DIST_KVSTORE
35 #include "ps/ps.h"
36 #endif // MXNET_USE_DIST_KVSTORE
37 
38 namespace mxnet {
45 class KVStore {
46  public:
48  virtual ~KVStore() {}
49 
60  static KVStore *Create(const char *type = "local");
61 
65  inline const std::string& type() { return type_; }
66 
83  virtual void Init(const std::vector<int>& keys,
84  const std::vector<NDArray>& values) = 0;
90  virtual void Init(const std::vector<std::string>& str_keys,
91  const std::vector<NDArray>& values) = 0;
128  virtual void Push(const std::vector<int>& keys,
129  const std::vector<NDArray>& values,
130  int priority = 0) = 0;
131 
138  virtual void Push(const std::vector<std::string>& str_keys,
139  const std::vector<NDArray>& values,
140  int priority = 0) = 0;
164  virtual void Pull(const std::vector<int>& keys,
165  const std::vector<NDArray*>& values,
166  int priority = 0) = 0;
173  virtual void Pull(const std::vector<std::string>& str_keys,
174  const std::vector<NDArray*>& values,
175  int priority = 0) = 0;
176 
185  virtual void PullRowSparse(const std::vector<int>& str_keys,
186  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
187  int priority = 0) = 0;
188 
197  virtual void PullRowSparse(const std::vector<std::string>& str_keys,
198  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
199  int priority = 0) = 0;
200 
204  typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
208  typedef std::function<void(const std::string&, const NDArray&, NDArray*)> StrUpdater;
218  virtual void set_updater(const Updater& updater) {
219  CHECK(updater) << "invalid updater";
220  updater_ = updater;
221  }
231  virtual void set_updater(const StrUpdater& updater) {
232  CHECK(updater) << "invalid updater";
233  str_updater_ = updater;
234  }
235 
236  /******************************************************
237  * the following are used for multi-machines.
238  ******************************************************/
239 
244  static void InitPSEnv(const std::unordered_map<std::string, std::string>& envs) {
245 #if MXNET_USE_DIST_KVSTORE
246  ps::Environment::Init(envs);
247 #else
248  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to init parameter server's environment";
249 #endif // MXNET_USE_DIST_KVSTORE
250  }
251 
257  static bool IsWorkerNode() {
258 #if MXNET_USE_DIST_KVSTORE
259  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
260  return (role_str == nullptr) || (!strcmp(role_str, "worker"));
261 #else
262  return true;
263 #endif // MXNET_USE_DIST_KVSTORE
264  }
265 
271  static bool IsServerNode() {
272 #if MXNET_USE_DIST_KVSTORE
273  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
274  return (role_str != nullptr) && (!strcmp(role_str, "server"));
275 #else
276  return false;
277 #endif // MXNET_USE_DIST_KVSTORE
278  }
279 
280  void set_barrier_before_exit(const bool barrier_before_exit) {
281 #if MXNET_USE_DIST_KVSTORE
282  if (!IsWorkerNode()) LOG(FATAL) << "barrier_before_exit takes effect only on worker nodes";
283  barrier_before_exit_ = barrier_before_exit;
284 #else
285  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to enable barrier";
286 #endif
287  }
288 
294  static bool IsSchedulerNode() {
295 #if MXNET_USE_DIST_KVSTORE
296  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
297  return (role_str != nullptr) && (!strcmp(role_str, "scheduler"));
298 #else
299  return false;
300 #endif // MXNET_USE_DIST_KVSTORE
301  }
302 
309  virtual int get_rank() const {
310  return 0;
311  }
312 
316  virtual int get_group_size() const {
317  return 1;
318  }
319 
328  virtual int get_num_dead_node(int node_id, int timeout = 60) const {
329  return 0;
330  }
331 
339  virtual void Barrier() { }
340 
352  virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
353 
357  typedef std::function<void(int, const std::string&)> Controller;
358 
372  virtual void RunServer(const Controller& controller) { }
373 
374  protected:
378  Updater updater_;
379 
383  StrUpdater str_updater_;
384 
388  std::string type_;
389 
393  std::atomic<bool> barrier_before_exit_{true};
394 };
395 
396 } // namespace mxnet
397 #endif // MXNET_KVSTORE_H_
distributed key-value store
Definition: kvstore.h:45
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:204
namespace of mxnet
Definition: base.h:126
virtual int get_rank() const
Definition: kvstore.h:309
static KVStore * Create(const char *type="local")
Factory function to create a new KVStore.
virtual void set_updater(const StrUpdater &updater)
set an updater with string keys
Definition: kvstore.h:231
Updater updater_
the user-defined updater
Definition: kvstore.h:378
virtual void PullRowSparse(const std::vector< int > &str_keys, const std::vector< std::pair< NDArray *, NDArray >> &val_rowids, int priority=0)=0
pull a list of key-value pairs from the store. The NDArray pulled back will be in row_sparse storage ...
const std::string & type()
return the type
Definition: kvstore.h:65
virtual void Pull(const std::vector< int > &keys, const std::vector< NDArray * > &values, int priority=0)=0
pull a list of key-value pairs from the store
static bool IsSchedulerNode()
Definition: kvstore.h:294
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:339
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:244
virtual void Init(const std::vector< int > &keys, const std::vector< NDArray > &values)=0
Initialize a list of key-value pairs to the store.
static bool IsWorkerNode()
Definition: kvstore.h:257
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:48
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:280
StrUpdater str_updater_
the user-defined updater with string keys
Definition: kvstore.h:383
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:328
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:372
std::function< void(const std::string &, const NDArray &, NDArray *)> StrUpdater
the prototype of user-defined updater with string keys
Definition: kvstore.h:208
virtual void Push(const std::vector< int > &keys, const std::vector< NDArray > &values, int priority=0)=0
push a list of key-value pairs into the store
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:352
std::string type_
the kvstore type
Definition: kvstore.h:388
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:357
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:218
virtual int get_group_size() const
Definition: kvstore.h:316
std::atomic< bool > barrier_before_exit_
whether to do a barrier when finalizing
Definition: kvstore.h:393
static bool IsServerNode()
Definition: kvstore.h:271