mxnet
kvstore.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_KVSTORE_H_
25 #define MXNET_KVSTORE_H_
26 #include <dmlc/io.h>
27 #include <vector>
28 #include <utility>
29 #include <unordered_map>
30 #include <string>
31 #include <functional>
32 #include <atomic>
33 #include "../../src/kvstore/gradient_compression.h"
34 #include "./ndarray.h"
35 #if MXNET_USE_DIST_KVSTORE
36 #include "ps/ps.h"
37 #endif // MXNET_USE_DIST_KVSTORE
38 
39 namespace mxnet {
40 
50 };
51 
58 class KVStore {
59  public:
61  virtual ~KVStore() {}
62 
73  static KVStore *Create(const char *type = "local");
74 
78  inline const std::string& type() { return type_; }
79 
85  virtual void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
86  & kwargs) = 0;
87 
104  virtual void Init(const std::vector<int>& keys,
105  const std::vector<NDArray>& values) = 0;
111  virtual void Init(const std::vector<std::string>& str_keys,
112  const std::vector<NDArray>& values) = 0;
149  virtual void Push(const std::vector<int>& keys,
150  const std::vector<NDArray>& values,
151  int priority = 0) = 0;
152 
159  virtual void Push(const std::vector<std::string>& str_keys,
160  const std::vector<NDArray>& values,
161  int priority = 0) = 0;
186  virtual void Pull(const std::vector<int>& keys,
187  const std::vector<NDArray*>& values,
188  int priority = 0, bool ignore_sparse = true) = 0;
196  virtual void Pull(const std::vector<std::string>& str_keys,
197  const std::vector<NDArray*>& values,
198  int priority = 0, bool ignore_sparse = true) = 0;
199 
208  virtual void Broadcast(const std::vector<int>& vkeys,
209  const std::vector<int>& okeys,
210  const std::vector<NDArray>& values,
211  const std::vector<NDArray*>& outs,
212  int priority = 0) = 0;
213 
222  virtual void Broadcast(const std::vector<std::string>& str_vkeys,
223  const std::vector<std::string>& str_okeys,
224  const std::vector<NDArray>& values,
225  const std::vector<NDArray*>& outs,
226  int priority = 0) = 0;
227 
236  virtual void PushPull(const std::vector<int>& vkeys,
237  const std::vector<int>& okeys,
238  const std::vector<NDArray>& values,
239  const std::vector<NDArray*>& outs,
240  int priority = 0) = 0;
241 
250  virtual void PushPull(const std::vector<std::string>& str_vkeys,
251  const std::vector<std::string>& str_okeys,
252  const std::vector<NDArray>& values,
253  const std::vector<NDArray*>& outs,
254  int priority = 0) = 0;
263  virtual void PullRowSparse(const std::vector<int>& str_keys,
264  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
265  int priority = 0) = 0;
266 
275  virtual void PullRowSparse(const std::vector<std::string>& str_keys,
276  const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
277  int priority = 0) = 0;
278 
282  typedef std::function<void(int, const NDArray&, NDArray*)> Updater;
286  typedef std::function<void(const std::string&, const NDArray&, NDArray*)> StrUpdater;
296  virtual void set_updater(const Updater& updater) {
297  CHECK(updater) << "invalid updater";
298  updater_ = updater;
299  }
300 
310  virtual void set_updater(const StrUpdater& updater) {
311  CHECK(updater) << "invalid updater";
312  str_updater_ = updater;
313  }
314 
315  /******************************************************
316  * the following are used for multi-machines.
317  ******************************************************/
318 
323  static void InitPSEnv(const std::unordered_map<std::string, std::string>& envs) {
324 #if MXNET_USE_DIST_KVSTORE
325  ps::Environment::Init(envs);
326 #else
327  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to init parameter server's environment";
328 #endif // MXNET_USE_DIST_KVSTORE
329  }
330 
336  static bool IsWorkerNode() {
337 #if MXNET_USE_DIST_KVSTORE
338  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
339  return (role_str == nullptr) || (!strcmp(role_str, "worker"));
340 #else
341  return true;
342 #endif // MXNET_USE_DIST_KVSTORE
343  }
344 
350  static bool IsServerNode() {
351 #if MXNET_USE_DIST_KVSTORE
352  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
353  return (role_str != nullptr) && (!strcmp(role_str, "server"));
354 #else
355  return false;
356 #endif // MXNET_USE_DIST_KVSTORE
357  }
358 
359  void set_barrier_before_exit(const bool barrier_before_exit) {
360 #if MXNET_USE_DIST_KVSTORE
361  if (!IsWorkerNode()) LOG(FATAL) << "barrier_before_exit takes effect only on worker nodes";
362  barrier_before_exit_ = barrier_before_exit;
363 #else
364  LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to enable barrier";
365 #endif
366  }
367 
373  static bool IsSchedulerNode() {
374 #if MXNET_USE_DIST_KVSTORE
375  const char* role_str = ps::Environment::Get()->find("DMLC_ROLE");
376  return (role_str != nullptr) && (!strcmp(role_str, "scheduler"));
377 #else
378  return false;
379 #endif // MXNET_USE_DIST_KVSTORE
380  }
381 
388  virtual int get_rank() const {
389  return 0;
390  }
391 
395  virtual int get_group_size() const {
396  return 1;
397  }
398 
407  virtual int get_num_dead_node(int node_id, int timeout = 60) const {
408  return 0;
409  }
410 
418  virtual void Barrier() { }
419 
431  virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }
432 
440  const std::string& params) {
441  LOG(INFO) << "Unable to pass server the profiler command. If you are using "
442  << "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1."
443  << "If you are training on single machine, then there is no server process"
444  << "to profile. Please profile the worker process instead.";
445  }
446 
450  typedef std::function<void(int, const std::string&)> Controller;
451 
465  virtual void RunServer(const Controller& controller) { }
466 
467  protected:
471  Updater updater_;
472 
476  StrUpdater str_updater_;
477 
481  std::string type_;
482 
487  std::shared_ptr<kvstore::GradientCompression> gradient_compression_;
488 
492  std::atomic<bool> barrier_before_exit_{true};
493 };
494 
495 } // namespace mxnet
496 #endif // MXNET_KVSTORE_H_
distributed key-value store
Definition: kvstore.h:58
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:282
namespace of mxnet
Definition: api_registry.h:33
virtual int get_group_size() const
Definition: kvstore.h:395
virtual void set_updater(const StrUpdater &updater)
set an updater with string keys
Definition: kvstore.h:310
Updater updater_
the user-defined updater
Definition: kvstore.h:471
const std::string & type()
return the type
Definition: kvstore.h:78
static bool IsSchedulerNode()
Definition: kvstore.h:373
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:418
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:323
static bool IsWorkerNode()
Definition: kvstore.h:336
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:61
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:359
virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type, const std::string &params)
Sends server profiler commands to all server nodes. Only the worker with rank=0 sends the command whic...
Definition: kvstore.h:439
StrUpdater str_updater_
the user-defined updater with string keys
Definition: kvstore.h:476
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:407
std::shared_ptr< kvstore::GradientCompression > gradient_compression_
Gradient compression object starts with GC_NONE mode. Used if SetGradientCompression sets the type...
Definition: kvstore.h:487
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:465
std::function< void(const std::string &, const NDArray &, NDArray *)> StrUpdater
the prototype of user-defined updater with string keys
Definition: kvstore.h:286
virtual int get_rank() const
Definition: kvstore.h:388
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:431
std::string type_
the kvstore type
Definition: kvstore.h:481
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:450
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:296
KVStoreServerProfilerCommand
enum to denote types of commands kvstore sends to server regarding profiler. kSetConfig sets profiler ...
Definition: kvstore.h:48
static bool IsServerNode()
Definition: kvstore.h:350