mxnet
kvstore.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_KVSTORE_H_
26 #define MXNET_KVSTORE_H_
#include <dmlc/io.h>
#include <atomic>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
34 #include "../../src/kvstore/gradient_compression.h"
35 #include "./ndarray.h"
36 #if MXNET_USE_DIST_KVSTORE
37 #include "ps/ps.h"
38 #endif // MXNET_USE_DIST_KVSTORE
39 
40 namespace mxnet {
41 
51 };
52 
59 class KVStore {
60  public:
62  virtual ~KVStore() {}
63 
74  static KVStore *Create(const char *type = "local");
75 
79  inline const std::string& type() { return type_; }
80 
86  virtual void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
87  & kwargs) = 0;
88 
105  virtual void Init(const std::vector<int>& keys,
106  const std::vector<NDArray>& values) = 0;
112  virtual void Init(const std::vector<std::string>& str_keys,
113  const std::vector<NDArray>& values) = 0;
150  virtual void Push(const std::vector<int>& keys,
151  const std::vector<NDArray>& values,
152  int priority = 0) = 0;
153 
160  virtual void Push(const std::vector<std::string>& str_keys,
161  const std::vector<NDArray>& values,
162  int priority = 0) = 0;
187  virtual void Pull(const std::vector<int>& keys,
188  const std::vector<NDArray*>& values,
189  int priority = 0, bool ignore_sparse = true) = 0;
197  virtual void Pull(const std::vector<std::string>& str_keys,
198  const std::vector<NDArray*>& values,
199  int priority = 0, bool ignore_sparse = true) = 0;
200 
209  virtual void PullRowSparse(const std::vector<int>& str_keys,
210  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
211  int priority = 0) = 0;
212 
221  virtual void PullRowSparse(const std::vector<std::string>& str_keys,
222  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
223  int priority = 0) = 0;
224 
228  typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
232  typedef std::function<void(const std::string&, const NDArray&, NDArray*)> StrUpdater;
242  virtual void set_updater(const Updater& updater) {
243  CHECK(updater) << "invalid updater";
244  updater_ = updater;
245  }
246 
256  virtual void set_updater(const StrUpdater& updater) {
257  CHECK(updater) << "invalid updater";
258  str_updater_ = updater;
259  }
260 
261  /******************************************************
262  * the following are used for multi-machines.
263  ******************************************************/
264 
269  static void InitPSEnv(const std::unordered_map<std::string, std::string>& envs) {
270 #if MXNET_USE_DIST_KVSTORE
271  ps::Environment::Init(envs);
272 #else
273  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to init parameter server's environment";
274 #endif // MXNET_USE_DIST_KVSTORE
275  }
276 
282  static bool IsWorkerNode() {
283 #if MXNET_USE_DIST_KVSTORE
284  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
285  return (role_str == nullptr) || (!strcmp(role_str, "worker"));
286 #else
287  return true;
288 #endif // MXNET_USE_DIST_KVSTORE
289  }
290 
296  static bool IsServerNode() {
297 #if MXNET_USE_DIST_KVSTORE
298  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
299  return (role_str != nullptr) && (!strcmp(role_str, "server"));
300 #else
301  return false;
302 #endif // MXNET_USE_DIST_KVSTORE
303  }
304 
305  void set_barrier_before_exit(const bool barrier_before_exit) {
306 #if MXNET_USE_DIST_KVSTORE
307  if (!IsWorkerNode()) LOG(FATAL) << "barrier_before_exit takes effect only on worker nodes";
308  barrier_before_exit_ = barrier_before_exit;
309 #else
310  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to enable barrier";
311 #endif
312  }
313 
319  static bool IsSchedulerNode() {
320 #if MXNET_USE_DIST_KVSTORE
321  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
322  return (role_str != nullptr) && (!strcmp(role_str, "scheduler"));
323 #else
324  return false;
325 #endif // MXNET_USE_DIST_KVSTORE
326  }
327 
334  virtual int get_rank() const {
335  return 0;
336  }
337 
341  virtual int get_group_size() const {
342  return 1;
343  }
344 
353  virtual int get_num_dead_node(int node_id, int timeout = 60) const {
354  return 0;
355  }
356 
364  virtual void Barrier() { }
365 
377  virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
378 
386  const std::string& params) {
387  LOG(INFO) << "Unable to pass server the profiler command. If you are using "
388  << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1."
389  << "If you are training on single machine, then there is no server process"
390  << "to profile. Please profile the worker process instead.";
391  }
392 
396  typedef std::function<void(int, const std::string&)> Controller;
397 
411  virtual void RunServer(const Controller& controller) { }
412 
413  protected:
417  Updater updater_;
418 
422  StrUpdater str_updater_;
423 
427  std::string type_;
428 
433  std::shared_ptr<kvstore::GradientCompression> gradient_compression_;
434 
438  std::atomic<bool> barrier_before_exit_{true};
439 };
440 
441 } // namespace mxnet
442 #endif // MXNET_KVSTORE_H_
distributed key-value store
Definition: kvstore.h:59
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:228
namespace of mxnet
Definition: base.h:89
virtual int get_rank() const
Definition: kvstore.h:334
virtual void set_updater(const StrUpdater &updater)
set an updater with string keys
Definition: kvstore.h:256
Updater updater_
the user-defined updater
Definition: kvstore.h:417
const std::string & type()
return the type
Definition: kvstore.h:79
static bool IsSchedulerNode()
Definition: kvstore.h:319
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:364
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:269
static bool IsWorkerNode()
Definition: kvstore.h:282
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:62
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:305
virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type, const std::string &params)
Sends server profiler commands to all server nodes. Only the worker with rank=0 sends the command, which...
Definition: kvstore.h:385
StrUpdater str_updater_
the user-defined updater with string keys
Definition: kvstore.h:422
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:353
std::shared_ptr< kvstore::GradientCompression > gradient_compression_
Gradient compression object; starts in GC_NONE mode. Used if SetGradientCompression sets the type...
Definition: kvstore.h:433
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:411
std::function< void(const std::string &, const NDArray &, NDArray *)> StrUpdater
the prototype of user-defined updater with string keys
Definition: kvstore.h:232
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:377
std::string type_
the kvstore type
Definition: kvstore.h:427
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:396
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:242
virtual int get_group_size() const
Definition: kvstore.h:341
KVStoreServerProfilerCommand
enum to denote types of commands kvstore sends to server regarding profiler. kSetConfig sets profiler ...
Definition: kvstore.h:49
static bool IsServerNode()
Definition: kvstore.h:296