25 #ifndef MXNET_KVSTORE_H_ 26 #define MXNET_KVSTORE_H_ 30 #include <unordered_map> 34 #include "../../src/kvstore/gradient_compression.h" 36 #if MXNET_USE_DIST_KVSTORE 38 #endif // MXNET_USE_DIST_KVSTORE 93 virtual void Init(
const std::vector<int>& keys,
94 const std::vector<NDArray>& values) = 0;
100 virtual void Init(
const std::vector<std::string>& str_keys,
101 const std::vector<NDArray>& values) = 0;
138 virtual void Push(
const std::vector<int>& keys,
139 const std::vector<NDArray>& values,
140 int priority = 0) = 0;
148 virtual void Push(
const std::vector<std::string>& str_keys,
149 const std::vector<NDArray>& values,
150 int priority = 0) = 0;
174 virtual void Pull(
const std::vector<int>& keys,
175 const std::vector<NDArray*>& values,
176 int priority = 0) = 0;
183 virtual void Pull(
const std::vector<std::string>& str_keys,
184 const std::vector<NDArray*>& values,
185 int priority = 0) = 0;
196 const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
197 int priority = 0) = 0;
207 virtual void PullRowSparse(
const std::vector<std::string>& str_keys,
208 const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
209 int priority = 0) = 0;
214 typedef std::function<void(int, const NDArray&, NDArray*)>
Updater;
218 typedef std::function<void(const std::string&, const NDArray&, NDArray*)>
StrUpdater;
229 CHECK(updater) <<
"invalid updater";
242 CHECK(updater) <<
"invalid updater";
254 static void InitPSEnv(
const std::unordered_map<std::string, std::string>& envs) {
255 #if MXNET_USE_DIST_KVSTORE 256 ps::Environment::Init(envs);
258 LOG(FATAL) <<
"compile with USE_DIST_KVSTORE=1 to init parameter server's environment";
259 #endif // MXNET_USE_DIST_KVSTORE 268 #if MXNET_USE_DIST_KVSTORE 269 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
270 return (role_str ==
nullptr) || (!strcmp(role_str,
"worker"));
273 #endif // MXNET_USE_DIST_KVSTORE 282 #if MXNET_USE_DIST_KVSTORE 283 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
284 return (role_str !=
nullptr) && (!strcmp(role_str,
"server"));
287 #endif // MXNET_USE_DIST_KVSTORE 291 #if MXNET_USE_DIST_KVSTORE 292 if (!
IsWorkerNode()) LOG(FATAL) <<
"barrier_before_exit takes effect only on worker nodes";
295 LOG(FATAL) <<
"compile with USE_DIST_KVSTORE=1 to enable barrier";
305 #if MXNET_USE_DIST_KVSTORE 306 const char* role_str = ps::Environment::Get()->find(
"DMLC_ROLE");
307 return (role_str !=
nullptr) && (!strcmp(role_str,
"scheduler"));
310 #endif // MXNET_USE_DIST_KVSTORE 367 typedef std::function<void(int, const std::string&)>
Controller;
382 virtual void RunServer(
const Controller& controller) { }
413 #endif // MXNET_KVSTORE_H_ distributed key-value store
Definition: kvstore.h:47
std::function< void(int, const NDArray &, NDArray *)> Updater
the prototype of user-defined updater
Definition: kvstore.h:214
virtual void SetGradientCompression(const std::vector< std::pair< std::string, std::string > > &kwargs)=0
Set parameters to use low-bit compressed gradients.
namespace of mxnet
Definition: base.h:127
virtual int get_rank() const
Definition: kvstore.h:319
static KVStore * Create(const char *type="local")
Factory function to create a new KVStore.
virtual void set_updater(const StrUpdater &updater)
set an updater with string keys
Definition: kvstore.h:241
Updater updater_
the user-defined updater
Definition: kvstore.h:388
virtual void PullRowSparse(const std::vector< int > &str_keys, const std::vector< std::pair< NDArray *, NDArray >> &val_rowids, int priority=0)=0
pull a list of key-value pairs from the store. The NDArray pulled back will be in row_sparse storage ...
const std::string & type()
return the type
Definition: kvstore.h:67
virtual void Pull(const std::vector< int > &keys, const std::vector< NDArray * > &values, int priority=0)=0
pull a list of key-value pairs from the store
static bool IsSchedulerNode()
Definition: kvstore.h:304
virtual void Barrier()
global barrier among all worker machines
Definition: kvstore.h:349
static void InitPSEnv(const std::unordered_map< std::string, std::string > &envs)
initialize ps-lite environment variables
Definition: kvstore.h:254
virtual void Init(const std::vector< int > &keys, const std::vector< NDArray > &values)=0
Initialize a list of key-value pairs in the store.
static bool IsWorkerNode()
Definition: kvstore.h:267
virtual ~KVStore()
virtual destructor
Definition: kvstore.h:50
void set_barrier_before_exit(const bool barrier_before_exit)
Definition: kvstore.h:290
StrUpdater str_updater_
the user-defined updater with string keys
Definition: kvstore.h:393
virtual int get_num_dead_node(int node_id, int timeout=60) const
Definition: kvstore.h:338
std::shared_ptr< kvstore::GradientCompression > gradient_compression_
Gradient compression object; starts with GC_NONE mode. Used if SetGradientCompression sets the type...
Definition: kvstore.h:404
virtual void RunServer(const Controller &controller)
Run as server (or scheduler)
Definition: kvstore.h:382
std::function< void(const std::string &, const NDArray &, NDArray *)> StrUpdater
the prototype of user-defined updater with string keys
Definition: kvstore.h:218
virtual void Push(const std::vector< int > &keys, const std::vector< NDArray > &values, int priority=0)=0
push a list of key-value pairs into the store
virtual void SendCommandToServers(int cmd_id, const std::string &cmd_body)
Send a command to all server nodes.
Definition: kvstore.h:362
std::string type_
the kvstore type
Definition: kvstore.h:398
std::function< void(int, const std::string &)> Controller
the prototype of a server controller
Definition: kvstore.h:367
virtual void set_updater(const Updater &updater)
set an updater
Definition: kvstore.h:228
virtual int get_group_size() const
Definition: kvstore.h:326
std::atomic< bool > barrier_before_exit_
whether to do barrier when finalize
Definition: kvstore.h:409
static bool IsServerNode()
Definition: kvstore.h:281