mxnet
kvstore.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_KVSTORE_H_
26 #define MXNET_KVSTORE_H_
27 #include <dmlc/io.h>
28 #include <vector>
29 #include <utility>
30 #include <unordered_map>
31 #include <string>
32 #include <functional>
33 #include <atomic>
34 #include "../../src/kvstore/gradient_compression.h"
35 #include "./ndarray.h"
36 #if MXNET_USE_DIST_KVSTORE
37 #include "ps/ps.h"
38 #endif // MXNET_USE_DIST_KVSTORE
39 
40 namespace mxnet {
47 class KVStore {
48  public:
50  virtual ~KVStore() {}
51 
62  static KVStore *Create(const char *type = "local");
63 
67  inline const std::string& type() { return type_; }
68 
74  virtual void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
75  & kwargs) = 0;
76 
93  virtual void Init(const std::vector<int>& keys,
94  const std::vector<NDArray>& values) = 0;
100  virtual void Init(const std::vector<std::string>& str_keys,
101  const std::vector<NDArray>& values) = 0;
138  virtual void Push(const std::vector<int>& keys,
139  const std::vector<NDArray>& values,
140  int priority = 0) = 0;
141 
148  virtual void Push(const std::vector<std::string>& str_keys,
149  const std::vector<NDArray>& values,
150  int priority = 0) = 0;
174  virtual void Pull(const std::vector<int>& keys,
175  const std::vector<NDArray*>& values,
176  int priority = 0) = 0;
183  virtual void Pull(const std::vector<std::string>& str_keys,
184  const std::vector<NDArray*>& values,
185  int priority = 0) = 0;
186 
195  virtual void PullRowSparse(const std::vector<int>& str_keys,
196  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
197  int priority = 0) = 0;
198 
207  virtual void PullRowSparse(const std::vector<std::string>& str_keys,
208  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
209  int priority = 0) = 0;
210 
214  typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
218  typedef std::function<void(const std::string&, const NDArray&, NDArray*)> StrUpdater;
228  virtual void set_updater(const Updater& updater) {
229  CHECK(updater) << "invalid updater";
230  updater_ = updater;
231  }
241  virtual void set_updater(const StrUpdater& updater) {
242  CHECK(updater) << "invalid updater";
243  str_updater_ = updater;
244  }
245 
246  /******************************************************
247  * the following are used for multi-machines.
248  ******************************************************/
249 
254  static void InitPSEnv(const std::unordered_map<std::string, std::string>& envs) {
255 #if MXNET_USE_DIST_KVSTORE
256  ps::Environment::Init(envs);
257 #else
258  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to init parameter server's environment";
259 #endif // MXNET_USE_DIST_KVSTORE
260  }
261 
267  static bool IsWorkerNode() {
268 #if MXNET_USE_DIST_KVSTORE
269  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
270  return (role_str == nullptr) || (!strcmp(role_str, "worker"));
271 #else
272  return true;
273 #endif // MXNET_USE_DIST_KVSTORE
274  }
275 
281  static bool IsServerNode() {
282 #if MXNET_USE_DIST_KVSTORE
283  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
284  return (role_str != nullptr) && (!strcmp(role_str, "server"));
285 #else
286  return false;
287 #endif // MXNET_USE_DIST_KVSTORE
288  }
289 
290  void set_barrier_before_exit(const bool barrier_before_exit) {
291 #if MXNET_USE_DIST_KVSTORE
292  if (!IsWorkerNode()) LOG(FATAL) << "barrier_before_exit takes effect only on worker nodes";
293  barrier_before_exit_ = barrier_before_exit;
294 #else
295  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to enable barrier";
296 #endif
297  }
298 
304  static bool IsSchedulerNode() {
305 #if MXNET_USE_DIST_KVSTORE
306  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
307  return (role_str != nullptr) && (!strcmp(role_str, "scheduler"));
308 #else
309  return false;
310 #endif // MXNET_USE_DIST_KVSTORE
311  }
312 
319  virtual int get_rank() const {
320  return 0;
321  }
322 
326  virtual int get_group_size() const {
327  return 1;
328  }
329 
338  virtual int get_num_dead_node(int node_id, int timeout = 60) const {
339  return 0;
340  }
341 
349  virtual void Barrier() { }
350 
362  virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
363 
367  typedef std::function<void(int, const std::string&)> Controller;
368 
382  virtual void RunServer(const Controller& controller) { }
383 
384  protected:
388  Updater updater_;
389 
393  StrUpdater str_updater_;
394 
398  std::string type_;
399 
404  std::shared_ptr<kvstore::GradientCompression> gradient_compression_;
405 
409  std::atomic<bool> barrier_before_exit_{true};
410 };
411 
412 } // namespace mxnet
413 #endif // MXNET_KVSTORE_H_
distributed key-value store
Definition: kvstore.h:47
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:214
virtual void SetGradientCompression(const std::vector< std::pair< std::string, std::string > > &kwargs)=0
Set parameters to use low-bit compressed gradients.
namespace of mxnet
Definition: base.h:118
virtual int get_rank() const
Definition: kvstore.h:319
static KVStore * Create(const char *type="local")
Factory function to create a new KVStore.
virtual void set_updater(const StrUpdater &updater)
set an updater with string keys
Definition: kvstore.h:241
Updater updater_
the user-defined updater
Definition: kvstore.h:388
virtual void PullRowSparse(const std::vector< int > &str_keys, const std::vector< std::pair< NDArray *, NDArray >> &val_rowids, int priority=0)=0
pull a list of key-value pairs from the store. The NDArray pulled back will be in row_sparse storage ...
const std::string & type()
return the type
Definition: kvstore.h:67
virtual void Pull(const std::vector< int > &keys, const std::vector< NDArray * > &values, int priority=0)=0
pull a list of key-value pairs from the store
static bool IsSchedulerNode()
Definition: kvstore.h:304
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:349
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:254
virtual void Init(const std::vector< int > &keys, const std::vector< NDArray > &values)=0
Initialize a list of key-value pairs in the store.
static bool IsWorkerNode()
Definition: kvstore.h:267
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:50
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:290
StrUpdater str_updater_
the user-defined updater with string keys
Definition: kvstore.h:393
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:338
std::shared_ptr< kvstore::GradientCompression > gradient_compression_
Gradient compression object; starts in GC_NONE mode. Used if SetGradientCompression sets the type...
Definition: kvstore.h:404
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:382
std::function< void(const std::string &, const NDArray &, NDArray *)> StrUpdater
the prototype of user-defined updater with string keys
Definition: kvstore.h:218
virtual void Push(const std::vector< int > &keys, const std::vector< NDArray > &values, int priority=0)=0
push a list of key-value pairs into the store
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:362
std::string type_
the kvstore type
Definition: kvstore.h:398
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:367
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:228
virtual int get_group_size() const
Definition: kvstore.h:326
std::atomic< bool > barrier_before_exit_
whether to perform a barrier when finalizing
Definition: kvstore.h:409
static bool IsServerNode()
Definition: kvstore.h:281