// NOTE(review): this text is a lossy scrape of a Doxygen source listing of
// kvstore.h. Original source line numbers are fused into the code and many
// lines (comments, closing braces, some headers) are missing. Make fixes in
// the real header, not in this copy.
// Create: static factory returning a KVStore*. The `type` argument defaults
// to "local"; presumably it selects which KVStore implementation is built
// (TODO confirm against the factory definition).
24 #ifndef MXNET_KVSTORE_H_ 25 #define MXNET_KVSTORE_H_ 29 #include <unordered_map> 33 #include "../../src/kvstore/gradient_compression.h" 35 #if MXNET_USE_DIST_KVSTORE 37 #endif // MXNET_USE_DIST_KVSTORE 73 static KVStore *Create(
const char *type =
"local");
// type(): returns the kvstore type string held in `type_`, by const reference.
// NOTE(review): it does not modify state, so it could be a const member function.
78 inline const std::string&
type() {
return type_; }
// SetGradientCompression: virtual; takes gradient-compression settings as a
// list of string key/value pairs. Per the index below, it is what sets the
// type of `gradient_compression_`.
// NOTE(review): the rest of this declaration (parameter name, closing
// parenthesis and any `= 0`) is missing from this listing.
85 virtual void SetGradientCompression(
const std::vector<std::pair<std::string, std::string> >
// Init: pure virtual, with two overloads (integer keys and string keys).
// Both take `values` by const reference. Presumably keys[i] is paired with
// values[i] to seed the store (TODO confirm against implementations).
104 virtual void Init(
const std::vector<int>& keys,
105 const std::vector<NDArray>& values) = 0;
// String-key overload of Init.
111 virtual void Init(
const std::vector<std::string>& str_keys,
112 const std::vector<NDArray>& values) = 0;
// Push: pure virtual, with integer-key and string-key overloads. `values`
// are inputs (const reference), presumably sent to the store under the
// matching keys. `priority` defaults to 0.
149 virtual void Push(
const std::vector<int>& keys,
150 const std::vector<NDArray>& values,
151 int priority = 0) = 0;
// String-key overload of Push.
159 virtual void Push(
const std::vector<std::string>& str_keys,
160 const std::vector<NDArray>& values,
161 int priority = 0) = 0;
// Pull: pure virtual, with integer-key and string-key overloads. `values`
// holds NDArray pointers, so results are presumably written into the
// pointed-to arrays (TODO confirm). Defaults: priority = 0,
// ignore_sparse = true.
186 virtual void Pull(
const std::vector<int>& keys,
187 const std::vector<NDArray*>& values,
188 int priority = 0,
bool ignore_sparse =
true) = 0;
// String-key overload of Pull; same defaults.
196 virtual void Pull(
const std::vector<std::string>& str_keys,
197 const std::vector<NDArray*>& values,
198 int priority = 0,
bool ignore_sparse =
true) = 0;
// Broadcast: pure virtual, with integer-key and string-key overloads.
// Takes one key list (`vkeys`) alongside input `values` and a second key
// list (`okeys`) alongside output pointers `outs`. `priority` defaults to 0.
// NOTE(review): assumes vkeys pair with values and okeys pair with outs
// position-by-position; confirm against implementations.
208 virtual void Broadcast(
const std::vector<int>& vkeys,
209 const std::vector<int>& okeys,
210 const std::vector<NDArray>& values,
211 const std::vector<NDArray*>& outs,
212 int priority = 0) = 0;
// String-key overload of Broadcast.
222 virtual void Broadcast(
const std::vector<std::string>& str_vkeys,
223 const std::vector<std::string>& str_okeys,
224 const std::vector<NDArray>& values,
225 const std::vector<NDArray*>& outs,
226 int priority = 0) = 0;
// PushPull: pure virtual, with integer-key and string-key overloads. It has
// the same parameter shape as Broadcast: input keys plus `values`, and
// output keys plus `outs` pointers. `priority` defaults to 0.
// NOTE(review): presumably it combines a push of `values` with a pull into
// `outs`; confirm against implementations.
236 virtual void PushPull(
const std::vector<int>& vkeys,
237 const std::vector<int>& okeys,
238 const std::vector<NDArray>& values,
239 const std::vector<NDArray*>& outs,
240 int priority = 0) = 0;
// String-key overload of PushPull.
250 virtual void PushPull(
const std::vector<std::string>& str_vkeys,
251 const std::vector<std::string>& str_okeys,
252 const std::vector<NDArray>& values,
253 const std::vector<NDArray*>& outs,
254 int priority = 0) = 0;
// PullRowSparse: pure virtual, with integer-key and string-key overloads.
// Each element of `val_rowids` pairs an output NDArray* with an NDArray,
// presumably the row ids to pull (TODO confirm). `priority` defaults to 0.
// NOTE(review): this integer-key overload names its parameter `str_keys`
// even though the type is std::vector<int>. `keys` would match the other
// int-key overloads; C++ callers are unaffected by parameter names.
263 virtual void PullRowSparse(
const std::vector<int>& str_keys,
264 const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
265 int priority = 0) = 0;
// String-key overload of PullRowSparse.
275 virtual void PullRowSparse(
const std::vector<std::string>& str_keys,
276 const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
277 int priority = 0) = 0;
// Updater: prototype of a user-defined updater keyed by int.
// Signature: (key, const NDArray&, NDArray*).
282 typedef std::function<void(int, const NDArray&, NDArray*)>
Updater;
// StrUpdater: the same prototype, but keyed by std::string.
286 typedef std::function<void(const std::string&, const NDArray&, NDArray*)>
StrUpdater;
// Fragments of the two set_updater overloads. Their headers, braces and the
// int-key assignment are missing from this listing. Both reject an empty
// std::function via CHECK with the message "invalid updater".
297 CHECK(updater) <<
"invalid updater";
// String-key overload: after the check, it stores the callback in str_updater_.
311 CHECK(updater) <<
"invalid updater";
312 str_updater_ = updater;
// InitPSEnv: initializes the ps-lite environment from `envs` via
// ps::Environment::Init when built with MXNET_USE_DIST_KVSTORE. Otherwise
// it calls LOG(FATAL).
// NOTE(review): several function headers, #else lines and closing braces
// below were dropped by the scrape, so bodies of adjacent functions run
// together on single lines.
323 static void InitPSEnv(
const std::unordered_map<std::string, std::string>& envs) {
324 #if MXNET_USE_DIST_KVSTORE 325 ps::Environment::Init(envs);
327 LOG(FATAL) <<
"compile with USE_DIST_KVSTORE=1 to init parameter server's environment";
// IsWorkerNode (starts mid-line below): returns true when DMLC_ROLE is
// unset (nullptr) or equals "worker". A node with no role is therefore
// treated as a worker.
328 #endif // MXNET_USE_DIST_KVSTORE 337 #if MXNET_USE_DIST_KVSTORE 338 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
339 return (role_str ==
nullptr) || (!strcmp(role_str,
"worker"));
// IsServerNode (starts mid-line below): returns true only when DMLC_ROLE
// is set and equals "server".
342 #endif // MXNET_USE_DIST_KVSTORE 351 #if MXNET_USE_DIST_KVSTORE 352 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
353 return (role_str !=
nullptr) && (!strcmp(role_str,
"server"));
// set_barrier_before_exit (starts mid-line below): on a dist build, it
// LOG(FATAL)s unless this is a worker node, then stores the flag in the
// atomic barrier_before_exit_. On a non-dist build it LOG(FATAL)s.
356 #endif // MXNET_USE_DIST_KVSTORE 360 #if MXNET_USE_DIST_KVSTORE 361 if (!IsWorkerNode()) LOG(FATAL) <<
"barrier_before_exit takes effect only on worker nodes";
362 barrier_before_exit_ = barrier_before_exit;
364 LOG(FATAL) <<
"compile with USE_DIST_KVSTORE=1 to enable barrier";
// IsSchedulerNode: returns true only when DMLC_ROLE is set and equals
// "scheduler".
374 #if MXNET_USE_DIST_KVSTORE 375 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
376 return (role_str !=
nullptr) && (!strcmp(role_str,
"scheduler"));
// SetServerProfilerCommand (header line missing; it continues mid-line
// below): this default virtual implementation only logs an INFO message
// saying the command cannot be passed to a server.
// NOTE(review): adjacent string literals below are concatenated without
// separating spaces, so the message reads "...USE_DIST_KVSTORE=1.If you
// are..." and "...no server processto profile...". Fix this in the real
// header by adding trailing spaces to those literals.
379 #endif // MXNET_USE_DIST_KVSTORE 440 const std::string& params) {
441 LOG(INFO) <<
"Unable to pass server the profiler command. If you are using " 442 <<
"distributed kvstore, you need to compile with USE_DIST_KVSTORE=1." 443 <<
"If you are training on single machine, then there is no server process" 444 <<
"to profile. Please profile the worker process instead.";
// Controller: prototype of a server controller, taking (int, const std::string&).
450 typedef std::function<void(int, const std::string&)>
Controller;
// RunServer: runs as server (or scheduler). The default virtual
// implementation has an empty body and does nothing.
465 virtual void RunServer(
const Controller& controller) { }
// barrier_before_exit_: atomic flag, initialized to true. It is written by
// set_barrier_before_exit.
492 std::atomic<bool> barrier_before_exit_{
true};
496 #endif // MXNET_KVSTORE_H_ distributed key-value store
Definition: kvstore.h:58
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:282
namespace of mxnet
Definition: api_registry.h:33
virtual int get_group_size() const
Definition: kvstore.h:395
virtual void set_updater(const StrUpdater &updater)
set an updater with string keys
Definition: kvstore.h:310
Updater updater_
the user-defined updater
Definition: kvstore.h:471
const std::string & type()
return the type
Definition: kvstore.h:78
static bool IsSchedulerNode()
Definition: kvstore.h:373
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:418
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:323
static bool IsWorkerNode()
Definition: kvstore.h:336
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:61
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:359
virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type, const std::string &params)
Sends server profiler commands to all server nodes. Only the worker with rank=0 sends the command which...
Definition: kvstore.h:439
StrUpdater str_updater_
the user-defined updater with string keys
Definition: kvstore.h:476
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:407
std::shared_ptr< kvstore::GradientCompression > gradient_compression_
Gradient compression object; starts in GC_NONE mode. Used if SetGradientCompression sets the type...
Definition: kvstore.h:487
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:465
std::function< void(const std::string &, const NDArray &, NDArray *)> StrUpdater
the prototype of user-defined updater with string keys
Definition: kvstore.h:286
virtual int get_rank() const
Definition: kvstore.h:388
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:431
std::string type_
the kvstore type
Definition: kvstore.h:481
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:450
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:296
KVStoreServerProfilerCommand
Enum to denote the types of commands kvstore sends to the server regarding the profiler. kSetConfig sets profiler ...
Definition: kvstore.h:48
static bool IsServerNode()
Definition: kvstore.h:350