mxnet
parameter.h
Go to the documentation of this file.
1 
6 #ifndef DMLC_PARAMETER_H_
7 #define DMLC_PARAMETER_H_
8 
9 #include <cstddef>
10 #include <cstdlib>
11 #include <cmath>
12 #include <sstream>
13 #include <limits>
14 #include <map>
15 #include <set>
16 #include <typeinfo>
17 #include <string>
18 #include <vector>
19 #include <algorithm>
20 #include <utility>
21 #include <stdexcept>
22 #include <iostream>
23 #include <iomanip>
24 #include <cerrno>
25 #include "./base.h"
26 #include "./json.h"
27 #include "./logging.h"
28 #include "./type_traits.h"
29 #include "./optional.h"
30 #include "./strtonum.h"
31 
32 namespace dmlc {
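// note: most of this file requires C++11 (it uses nullptr, range-based for and
// override); only the float/double specializations are explicitly guarded by
// DMLC_USE_CXX11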
35 struct ParamError : public dmlc::Error {
40  explicit ParamError(const std::string &msg)
41  : dmlc::Error(msg) {}
42 };
43 
/*!
 * \brief read an environment variable and parse it as ValueType.
 * \param key name of the environment variable.
 * \param default_value returned when the variable is unset or empty.
 * \return the parsed value, or default_value.
 * \tparam ValueType type to parse the variable into; parsing uses the
 *  same FieldEntry rules as parameter fields (errors raise ParamError).
 */
template<typename ValueType>
inline ValueType GetEnv(const char *key, ValueType default_value);

/*!
 * \brief set an environment variable to the string form of value,
 *  overwriting any previous value.
 * \param key name of the environment variable.
 * \param value value to store; it is printed the same way a parameter
 *  field of type ValueType is printed.
 * \tparam ValueType type of the value.
 */
template<typename ValueType>
inline void SetEnv(const char *key, ValueType value);
62 
namespace parameter {
// Forward declarations of the machinery defined further down.
class ParamManager;
class FieldAccessEntry;
template<typename DType>
class FieldEntry;
template<typename PType>
struct ParamManagerSingleton;

/*! \brief how keyword arguments that match no declared field are treated */
enum ParamInitOption {
  /*! \brief never raise for unmatched arguments */
  kAllowUnknown,
  /*! \brief every argument must match a declared field */
  kAllMatch,
  /*!
   * \brief like kAllMatch, but keys shaped like "__name__" (longer than
   *  four characters) are skipped instead of rejected
   */
  kAllowHidden
};
}  // namespace parameter
/*! \brief human-readable description of one declared parameter field */
struct ParamFieldInfo {
  /*! \brief key of the field as used in keyword arguments */
  std::string name;
  /*! \brief type name of the field */
  std::string type;
  /*!
   * \brief type plus default/required information, e.g. "int, required";
   *  this is the text printed in the generated doc string
   */
  std::string type_info_str;
  /*! \brief free-form description supplied via describe() */
  std::string description;
};
102 
127 template<typename PType>
128 struct Parameter {
129  public:
140  template<typename Container>
141  inline void Init(const Container &kwargs,
142  parameter::ParamInitOption option = parameter::kAllowHidden) {
143  PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
144  kwargs.begin(), kwargs.end(),
145  NULL,
146  option);
147  }
157  template<typename Container>
158  inline std::vector<std::pair<std::string, std::string> >
159  InitAllowUnknown(const Container &kwargs) {
160  std::vector<std::pair<std::string, std::string> > unknown;
161  PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
162  kwargs.begin(), kwargs.end(),
163  &unknown, parameter::kAllowUnknown);
164  return unknown;
165  }
166 
179  template <typename Container>
180  std::vector<std::pair<std::string, std::string> >
181  UpdateAllowUnknown(Container const& kwargs, bool* out_changed = nullptr) {
182  std::vector<std::pair<std::string, std::string> > unknown;
183  bool changed {false};
184  changed = PType::__MANAGER__()->RunUpdate(static_cast<PType*>(this),
185  kwargs.begin(), kwargs.end(),
186  parameter::kAllowUnknown, &unknown, nullptr);
187  if (out_changed) { *out_changed = changed; }
188  return unknown;
189  }
190 
197  template<typename Container>
198  inline void UpdateDict(Container *dict) const {
199  PType::__MANAGER__()->UpdateDict(this->head(), dict);
200  }
205  inline std::map<std::string, std::string> __DICT__() const {
206  std::vector<std::pair<std::string, std::string> > vec
207  = PType::__MANAGER__()->GetDict(this->head());
208  return std::map<std::string, std::string>(vec.begin(), vec.end());
209  }
214  inline void Save(dmlc::JSONWriter *writer) const {
215  writer->Write(this->__DICT__());
216  }
222  inline void Load(dmlc::JSONReader *reader) {
223  std::map<std::string, std::string> kwargs;
224  reader->Read(&kwargs);
225  this->Init(kwargs);
226  }
231  inline static std::vector<ParamFieldInfo> __FIELDS__() {
232  return PType::__MANAGER__()->GetFieldInfo();
233  }
238  inline static std::string __DOC__() {
239  std::ostringstream os;
240  PType::__MANAGER__()->PrintDocString(os);
241  return os.str();
242  }
243 
244  protected:
251  template<typename DType>
252  inline parameter::FieldEntry<DType>& DECLARE(
253  parameter::ParamManagerSingleton<PType> *manager,
254  const std::string &key, DType &ref) { // NOLINT(*)
255  parameter::FieldEntry<DType> *e =
256  new parameter::FieldEntry<DType>();
257  e->Init(key, this->head(), ref);
258  manager->manager.AddEntry(key, e);
259  return *e;
260  }
261 
262  private:
264  inline PType *head() const {
265  return static_cast<PType*>(const_cast<Parameter<PType>*>(this));
266  }
267 };
268 
270 
/*!
 * \brief declare the field list of a parameter struct.
 *
 * Place inside the struct and follow it with a brace-enclosed block of
 * DMLC_DECLARE_FIELD / DMLC_DECLARE_ALIAS statements. The expansion ends
 * with a function declarator without a body, so that block becomes the
 * body of __DECLARE__ (which receives the singleton as `manager`).
 * \param PType the parameter struct being declared.
 */
#define DMLC_DECLARE_PARAMETER(PType)                                    \
  static ::dmlc::parameter::ParamManager *__MANAGER__();                 \
  inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \

/*!
 * \brief declare one field inside the DMLC_DECLARE_PARAMETER block.
 * \param FieldName member of the parameter struct; its spelling becomes
 *  the keyword-argument key. The result supports chained calls such as
 *  .set_default() and .describe().
 */
#define DMLC_DECLARE_FIELD(FieldName) \
  this->DECLARE(manager, #FieldName, FieldName)
298 
/*!
 * \brief inside the DMLC_DECLARE_PARAMETER block, register AliasName as an
 *  additional key for the already-declared field FieldName.
 */
#define DMLC_DECLARE_ALIAS(FieldName, AliasName) \
  manager->manager.AddAlias(#FieldName, #AliasName)
305 
/*!
 * \brief define PType::__MANAGER__() and bind a namespace-scope reference to
 *  its result so the singleton is built during static initialization.
 *
 * Use once per PType at namespace scope (it defines a function); the user
 * supplies the terminating semicolon.
 * \param PType the parameter struct to register.
 */
#define DMLC_REGISTER_PARAMETER(PType)                                   \
  ::dmlc::parameter::ParamManager *PType::__MANAGER__() {                \
    static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
    return &inst.manager;                                                \
  }                                                                      \
  static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager&          \
  __make__ ## PType ## ParamManager__ =                                  \
      (*PType::__MANAGER__())                                            \

328 namespace parameter {
335 class FieldAccessEntry {
336  public:
337  FieldAccessEntry()
338  : has_default_(false), index_(0) {}
340  virtual ~FieldAccessEntry() {}
346  virtual void SetDefault(void *head) const = 0;
352  virtual void Set(void *head, const std::string &value) const = 0;
358  virtual bool Same(void* head, const std::string& value) const = 0;
359  // check if value is OK
360  virtual void Check(void *head) const {}
365  virtual std::string GetStringValue(void *head) const = 0;
370  virtual ParamFieldInfo GetFieldInfo() const = 0;
371 
372  protected:
374  bool has_default_;
376  size_t index_;
378  std::string key_;
380  std::string type_;
382  std::string description_;
383  // internal offset of the field
384  ptrdiff_t offset_;
386  char* GetRawPtr(void* head) const {
387  return reinterpret_cast<char*>(head) + offset_;
388  }
393  virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*)
394  // allow ParamManager to modify self
395  friend class ParamManager;
396 };
397 
402 class ParamManager {
403  public:
405  ~ParamManager() {
406  for (size_t i = 0; i < entry_.size(); ++i) {
407  delete entry_[i];
408  }
409  }
415  inline FieldAccessEntry *Find(const std::string &key) const {
416  std::map<std::string, FieldAccessEntry*>::const_iterator it =
417  entry_map_.find(key);
418  if (it == entry_map_.end()) return NULL;
419  return it->second;
420  }
431  template<typename RandomAccessIterator>
432  inline void RunInit(void *head,
433  RandomAccessIterator begin,
434  RandomAccessIterator end,
435  std::vector<std::pair<std::string, std::string> > *unknown_args,
436  parameter::ParamInitOption option) const {
437  std::set<FieldAccessEntry*> selected_args;
438  RunUpdate(head, begin, end, option, unknown_args, &selected_args);
439  for (auto const& kv : entry_map_) {
440  if (selected_args.find(kv.second) == selected_args.cend()) {
441  kv.second->SetDefault(head);
442  }
443  }
444  for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
445  it != entry_map_.end(); ++it) {
446  if (selected_args.count(it->second) == 0) {
447  it->second->SetDefault(head);
448  }
449  }
450  }
463  template <typename RandomAccessIterator>
464  bool RunUpdate(void *head,
465  RandomAccessIterator begin,
466  RandomAccessIterator end,
467  parameter::ParamInitOption option,
468  std::vector<std::pair<std::string, std::string> > *unknown_args,
469  std::set<FieldAccessEntry*>* selected_args = nullptr) const {
470  bool changed {false};
471  for (RandomAccessIterator it = begin; it != end; ++it) {
472  if (FieldAccessEntry *e = Find(it->first)) {
473  if (!e->Same(head, it->second)) {
474  changed = true;
475  }
476  e->Set(head, it->second);
477  e->Check(head);
478  if (selected_args) {
479  selected_args->insert(e);
480  }
481  } else {
482  if (unknown_args != NULL) {
483  unknown_args->push_back(*it);
484  } else {
485  if (option != parameter::kAllowUnknown) {
486  if (option == parameter::kAllowHidden &&
487  it->first.length() > 4 &&
488  it->first.find("__") == 0 &&
489  it->first.rfind("__") == it->first.length()-2) {
490  continue;
491  }
492  std::ostringstream os;
493  os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
494  os << "----------------\n";
495  PrintDocString(os);
496  throw dmlc::ParamError(os.str());
497  }
498  }
499  }
500  }
501  return changed;
502  }
509  inline void AddEntry(const std::string &key, FieldAccessEntry *e) {
510  e->index_ = entry_.size();
511  // TODO(bing) better error message
512  if (entry_map_.count(key) != 0) {
513  LOG(FATAL) << "key " << key << " has already been registered in " << name_;
514  }
515  entry_.push_back(e);
516  entry_map_[key] = e;
517  }
524  inline void AddAlias(const std::string& field, const std::string& alias) {
525  if (entry_map_.count(field) == 0) {
526  LOG(FATAL) << "key " << field << " has not been registered in " << name_;
527  }
528  if (entry_map_.count(alias) != 0) {
529  LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_;
530  }
531  entry_map_[alias] = entry_map_[field];
532  }
537  inline void set_name(const std::string &name) {
538  name_ = name;
539  }
544  inline std::vector<ParamFieldInfo> GetFieldInfo() const {
545  std::vector<ParamFieldInfo> ret(entry_.size());
546  for (size_t i = 0; i < entry_.size(); ++i) {
547  ret[i] = entry_[i]->GetFieldInfo();
548  }
549  return ret;
550  }
555  inline void PrintDocString(std::ostream &os) const { // NOLINT(*)
556  for (size_t i = 0; i < entry_.size(); ++i) {
557  ParamFieldInfo info = entry_[i]->GetFieldInfo();
558  os << info.name << " : " << info.type_info_str << '\n';
559  if (info.description.length() != 0) {
560  os << " " << info.description << '\n';
561  }
562  }
563  }
570  inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const {
571  std::vector<std::pair<std::string, std::string> > ret;
572  for (std::map<std::string, FieldAccessEntry*>::const_iterator
573  it = entry_map_.begin(); it != entry_map_.end(); ++it) {
574  ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
575  }
576  return ret;
577  }
584  template<typename Container>
585  inline void UpdateDict(void * head, Container* dict) const {
586  for (std::map<std::string, FieldAccessEntry*>::const_iterator
587  it = entry_map_.begin(); it != entry_map_.end(); ++it) {
588  (*dict)[it->first] = it->second->GetStringValue(head);
589  }
590  }
591 
592  private:
594  std::string name_;
596  std::vector<FieldAccessEntry*> entry_;
598  std::map<std::string, FieldAccessEntry*> entry_map_;
599 };
600 
602 
603 // The following piece of code will be template heavy and less documented
604 // singleton parameter manager for certain type, used for initialization
605 template<typename PType>
606 struct ParamManagerSingleton {
607  ParamManager manager;
608  explicit ParamManagerSingleton(const std::string &param_name) {
609  PType param;
610  manager.set_name(param_name);
611  param.__DECLARE__(this);
612  }
613 };
614 
615 // Base class of FieldEntry
616 // implement set_default
617 template<typename TEntry, typename DType>
618 class FieldEntryBase : public FieldAccessEntry {
619  public:
620  // entry type
621  typedef TEntry EntryType;
622  // implement set value
623  void Set(void *head, const std::string &value) const override {
624  std::istringstream is(value);
625  is >> this->Get(head);
626  if (!is.fail()) {
627  while (!is.eof()) {
628  int ch = is.get();
629  if (ch == EOF) {
630  is.clear(); break;
631  }
632  if (!isspace(ch)) {
633  is.setstate(std::ios::failbit); break;
634  }
635  }
636  }
637 
638  if (is.fail()) {
639  std::ostringstream os;
640  os << "Invalid Parameter format for " << key_
641  << " expect " << type_ << " but value=\'" << value<< '\'';
642  throw dmlc::ParamError(os.str());
643  }
644  }
645 
646  // Don't check this function for Undefined Behavior (UB), as the function
647  // reads from a possibly uninitialized field
649  bool Same(void* head, std::string const& value) const override {
650  DType old = this->Get(head);
651  DType now;
652  std::istringstream is(value);
653  is >> now;
654  // don't require = operator
655  bool is_same = std::equal(
656  reinterpret_cast<char*>(&now), reinterpret_cast<char*>(&now) + sizeof(now),
657  reinterpret_cast<char*>(&old));
658  return is_same;
659  }
660  std::string GetStringValue(void *head) const override {
661  std::ostringstream os;
662  PrintValue(os, this->Get(head));
663  return os.str();
664  }
665  ParamFieldInfo GetFieldInfo() const override {
666  ParamFieldInfo info;
667  std::ostringstream os;
668  info.name = key_;
669  info.type = type_;
670  os << type_;
671  if (has_default_) {
672  os << ',' << " optional, default=";
673  PrintDefaultValueString(os);
674  } else {
675  os << ", required";
676  }
677  info.type_info_str = os.str();
678  info.description = description_;
679  return info;
680  }
681  // implement set head to default value
682  void SetDefault(void *head) const override {
683  if (!has_default_) {
684  std::ostringstream os;
685  os << "Required parameter " << key_
686  << " of " << type_ << " is not presented";
687  throw dmlc::ParamError(os.str());
688  } else {
689  this->Get(head) = default_value_;
690  }
691  }
692  // return reference of self as derived type
693  inline TEntry &self() {
694  return *(static_cast<TEntry*>(this));
695  }
696  // implement set_default
697  inline TEntry &set_default(const DType &default_value) {
698  default_value_ = default_value;
699  has_default_ = true;
700  // return self to allow chaining
701  return this->self();
702  }
703  // implement describe
704  inline TEntry &describe(const std::string &description) {
705  description_ = description;
706  // return self to allow chaining
707  return this->self();
708  }
709  // initialization function
710  inline void Init(const std::string &key,
711  void *head, DType &ref) { // NOLINT(*)
712  this->key_ = key;
713  if (this->type_.length() == 0) {
714  this->type_ = dmlc::type_name<DType>();
715  }
716  this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*)
717  }
718 
719  protected:
720  // print the value
721  virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*)
722  os << value;
723  }
724  void PrintDefaultValueString(std::ostream &os) const override { // NOLINT(*)
725  PrintValue(os, default_value_);
726  }
727  // get the internal representation of parameter
728  // for example if this entry corresponds field param.learning_rate
729  // then Get(&param) will return reference to param.learning_rate
730  inline DType &Get(void *head) const {
731  return *(DType*)this->GetRawPtr(head); // NOLINT(*)
732  }
733  // default value of field
734  DType default_value_;
735 };
736 
737 // parameter base for numeric types that have range
738 template<typename TEntry, typename DType>
739 class FieldEntryNumeric
740  : public FieldEntryBase<TEntry, DType> {
741  public:
742  FieldEntryNumeric()
743  : has_begin_(false), has_end_(false) {}
744  // implement set_range
745  virtual TEntry &set_range(DType begin, DType end) {
746  begin_ = begin; end_ = end;
747  has_begin_ = true; has_end_ = true;
748  return this->self();
749  }
750  // implement set_range
751  virtual TEntry &set_lower_bound(DType begin) {
752  begin_ = begin; has_begin_ = true;
753  return this->self();
754  }
755  // consistency check for numeric ranges
756  virtual void Check(void *head) const {
757  FieldEntryBase<TEntry, DType>::Check(head);
758  DType v = this->Get(head);
759  if (has_begin_ && has_end_) {
760  if (v < begin_ || v > end_) {
761  std::ostringstream os;
762  os << "value " << v << " for Parameter " << this->key_
763  << " exceed bound [" << begin_ << ',' << end_ <<']' << '\n';
764  os << this->key_ << ": " << this->description_;
765  throw dmlc::ParamError(os.str());
766  }
767  } else if (has_begin_ && v < begin_) {
768  std::ostringstream os;
769  os << "value " << v << " for Parameter " << this->key_
770  << " should be greater equal to " << begin_ << '\n';
771  os << this->key_ << ": " << this->description_;
772  throw dmlc::ParamError(os.str());
773  } else if (has_end_ && v > end_) {
774  std::ostringstream os;
775  os << "value " << v << " for Parameter " << this->key_
776  << " should be smaller equal to " << end_ << '\n';
777  os << this->key_ << ": " << this->description_;
778  throw dmlc::ParamError(os.str());
779  }
780  }
781 
782  protected:
783  // whether it have begin and end range
784  bool has_begin_, has_end_;
785  // data bound
786  DType begin_, end_;
787 };
788 
794 template<typename DType>
795 class FieldEntry :
796  public IfThenElseType<dmlc::is_arithmetic<DType>::value,
797  FieldEntryNumeric<FieldEntry<DType>, DType>,
798  FieldEntryBase<FieldEntry<DType>, DType> >::Type {
799 };
800 
801 // specialize define for int(enum)
802 template<>
803 class FieldEntry<int>
804  : public FieldEntryNumeric<FieldEntry<int>, int> {
805  public:
806  // construct
807  FieldEntry<int>() : is_enum_(false) {}
808  // parent
809  typedef FieldEntryNumeric<FieldEntry<int>, int> Parent;
810  // override set
811  virtual void Set(void *head, const std::string &value) const {
812  if (is_enum_) {
813  std::map<std::string, int>::const_iterator it = enum_map_.find(value);
814  std::ostringstream os;
815  if (it == enum_map_.end()) {
816  os << "Invalid Input: \'" << value;
817  os << "\', valid values are: ";
818  PrintEnums(os);
819  throw dmlc::ParamError(os.str());
820  } else {
821  os << it->second;
822  Parent::Set(head, os.str());
823  }
824  } else {
825  Parent::Set(head, value);
826  }
827  }
828  virtual ParamFieldInfo GetFieldInfo() const {
829  if (is_enum_) {
830  ParamFieldInfo info;
831  std::ostringstream os;
832  info.name = key_;
833  info.type = type_;
834  PrintEnums(os);
835  if (has_default_) {
836  os << ',' << "optional, default=";
837  PrintDefaultValueString(os);
838  } else {
839  os << ", required";
840  }
841  info.type_info_str = os.str();
842  info.description = description_;
843  return info;
844  } else {
845  return Parent::GetFieldInfo();
846  }
847  }
848  // add enum
849  inline FieldEntry<int> &add_enum(const std::string &key, int value) {
850  if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
851  enum_back_map_.count(value) != 0) {
852  std::ostringstream os;
853  os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
854  os << "Enums: ";
855  for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
856  it != enum_map_.end(); ++it) {
857  os << "(" << it->first << ": " << it->second << "), ";
858  }
859  throw dmlc::ParamError(os.str());
860  }
861  enum_map_[key] = value;
862  enum_back_map_[value] = key;
863  is_enum_ = true;
864  return this->self();
865  }
866 
867  protected:
868  // enum flag
869  bool is_enum_;
870  // enum map
871  std::map<std::string, int> enum_map_;
872  // enum map
873  std::map<int, std::string> enum_back_map_;
874  // override print behavior
875  virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
876  os << '\'';
877  PrintValue(os, default_value_);
878  os << '\'';
879  }
880  // override print default
881  virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*)
882  if (is_enum_) {
883  CHECK_NE(enum_back_map_.count(value), 0U)
884  << "Value not found in enum declared";
885  os << enum_back_map_.at(value);
886  } else {
887  os << value;
888  }
889  }
890 
891 
892  private:
893  inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
894  os << '{';
895  for (std::map<std::string, int>::const_iterator
896  it = enum_map_.begin(); it != enum_map_.end(); ++it) {
897  if (it != enum_map_.begin()) {
898  os << ", ";
899  }
900  os << "\'" << it->first << '\'';
901  }
902  os << '}';
903  }
904 };
905 
906 
907 // specialize define for optional<int>(enum)
908 template<>
909 class FieldEntry<optional<int> >
910  : public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
911  public:
912  // construct
913  FieldEntry<optional<int> >() : is_enum_(false) {}
914  // parent
915  typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
916  // override set
917  virtual void Set(void *head, const std::string &value) const {
918  if (is_enum_ && value != "None") {
919  std::map<std::string, int>::const_iterator it = enum_map_.find(value);
920  std::ostringstream os;
921  if (it == enum_map_.end()) {
922  os << "Invalid Input: \'" << value;
923  os << "\', valid values are: ";
924  PrintEnums(os);
925  throw dmlc::ParamError(os.str());
926  } else {
927  os << it->second;
928  Parent::Set(head, os.str());
929  }
930  } else {
931  Parent::Set(head, value);
932  }
933  }
934  virtual ParamFieldInfo GetFieldInfo() const {
935  if (is_enum_) {
936  ParamFieldInfo info;
937  std::ostringstream os;
938  info.name = key_;
939  info.type = type_;
940  PrintEnums(os);
941  if (has_default_) {
942  os << ',' << "optional, default=";
943  PrintDefaultValueString(os);
944  } else {
945  os << ", required";
946  }
947  info.type_info_str = os.str();
948  info.description = description_;
949  return info;
950  } else {
951  return Parent::GetFieldInfo();
952  }
953  }
954  // add enum
955  inline FieldEntry<optional<int> > &add_enum(const std::string &key, int value) {
956  CHECK_NE(key, "None") << "None is reserved for empty optional<int>";
957  if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
958  enum_back_map_.count(value) != 0) {
959  std::ostringstream os;
960  os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
961  os << "Enums: ";
962  for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
963  it != enum_map_.end(); ++it) {
964  os << "(" << it->first << ": " << it->second << "), ";
965  }
966  throw dmlc::ParamError(os.str());
967  }
968  enum_map_[key] = value;
969  enum_back_map_[value] = key;
970  is_enum_ = true;
971  return this->self();
972  }
973 
974  protected:
975  // enum flag
976  bool is_enum_;
977  // enum map
978  std::map<std::string, int> enum_map_;
979  // enum map
980  std::map<int, std::string> enum_back_map_;
981  // override print behavior
982  virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
983  os << '\'';
984  PrintValue(os, default_value_);
985  os << '\'';
986  }
987  // override print default
988  virtual void PrintValue(std::ostream &os, optional<int> value) const { // NOLINT(*)
989  if (is_enum_) {
990  if (!value) {
991  os << "None";
992  } else {
993  CHECK_NE(enum_back_map_.count(value.value()), 0U)
994  << "Value not found in enum declared";
995  os << enum_back_map_.at(value.value());
996  }
997  } else {
998  os << value;
999  }
1000  }
1001 
1002 
1003  private:
1004  inline void PrintEnums(std::ostream &os) const { // NOLINT(*)
1005  os << "{None";
1006  for (std::map<std::string, int>::const_iterator
1007  it = enum_map_.begin(); it != enum_map_.end(); ++it) {
1008  os << ", ";
1009  os << "\'" << it->first << '\'';
1010  }
1011  os << '}';
1012  }
1013 };
1014 
1015 // specialize define for string
1016 template<>
1017 class FieldEntry<std::string>
1018  : public FieldEntryBase<FieldEntry<std::string>, std::string> {
1019  public:
1020  // parent class
1021  typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
1022  // override set
1023  virtual void Set(void *head, const std::string &value) const {
1024  this->Get(head) = value;
1025  }
1026  // override print default
1027  virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*)
1028  os << '\'' << default_value_ << '\'';
1029  }
1030 };
1031 
1032 // specialize define for bool
1033 template<>
1034 class FieldEntry<bool>
1035  : public FieldEntryBase<FieldEntry<bool>, bool> {
1036  public:
1037  // parent class
1038  typedef FieldEntryBase<FieldEntry<bool>, bool> Parent;
1039  // override set
1040  virtual void Set(void *head, const std::string &value) const {
1041  std::string lower_case; lower_case.resize(value.length());
1042  std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower);
1043  bool &ref = this->Get(head);
1044  if (lower_case == "true") {
1045  ref = true;
1046  } else if (lower_case == "false") {
1047  ref = false;
1048  } else if (lower_case == "1") {
1049  ref = true;
1050  } else if (lower_case == "0") {
1051  ref = false;
1052  } else {
1053  std::ostringstream os;
1054  os << "Invalid Parameter format for " << key_
1055  << " expect " << type_ << " but value=\'" << value<< '\'';
1056  throw dmlc::ParamError(os.str());
1057  }
1058  }
1059 
1060  protected:
1061  // print default string
1062  virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*)
1063  os << static_cast<int>(value);
1064  }
1065 };
1066 
1067 
1068 // specialize define for float. Uses stof for platform independent handling of
1069 // INF, -INF, NAN, etc.
1070 #if DMLC_USE_CXX11
1071 template <>
1072 class FieldEntry<float> : public FieldEntryNumeric<FieldEntry<float>, float> {
1073  public:
1074  // parent
1075  typedef FieldEntryNumeric<FieldEntry<float>, float> Parent;
1076  // override set
1077  virtual void Set(void *head, const std::string &value) const {
1078  size_t pos = 0; // number of characters processed by dmlc::stof()
1079  try {
1080  this->Get(head) = dmlc::stof(value, &pos);
1081  } catch (const std::invalid_argument &) {
1082  std::ostringstream os;
1083  os << "Invalid Parameter format for " << key_ << " expect " << type_
1084  << " but value=\'" << value << '\'';
1085  throw dmlc::ParamError(os.str());
1086  } catch (const std::out_of_range&) {
1087  std::ostringstream os;
1088  os << "Out of range value for " << key_ << ", value=\'" << value << '\'';
1089  throw dmlc::ParamError(os.str());
1090  }
1091  CHECK_LE(pos, value.length()); // just in case
1092  if (pos < value.length()) {
1093  std::ostringstream os;
1094  os << "Some trailing characters could not be parsed: \'"
1095  << value.substr(pos) << "\'";
1096  throw dmlc::ParamError(os.str());
1097  }
1098  }
1099 
1100  protected:
1101  // print the value
1102  virtual void PrintValue(std::ostream &os, float value) const { // NOLINT(*)
1103  os << std::setprecision(std::numeric_limits<float>::max_digits10) << value;
1104  }
1105 };
1106 
1107 // specialize define for double. Uses stod for platform independent handling of
1108 // INF, -INF, NAN, etc.
1109 template <>
1110 class FieldEntry<double>
1111  : public FieldEntryNumeric<FieldEntry<double>, double> {
1112  public:
1113  // parent
1114  typedef FieldEntryNumeric<FieldEntry<double>, double> Parent;
1115  // override set
1116  virtual void Set(void *head, const std::string &value) const {
1117  size_t pos = 0; // number of characters processed by dmlc::stod()
1118  try {
1119  this->Get(head) = dmlc::stod(value, &pos);
1120  } catch (const std::invalid_argument &) {
1121  std::ostringstream os;
1122  os << "Invalid Parameter format for " << key_ << " expect " << type_
1123  << " but value=\'" << value << '\'';
1124  throw dmlc::ParamError(os.str());
1125  } catch (const std::out_of_range&) {
1126  std::ostringstream os;
1127  os << "Out of range value for " << key_ << ", value=\'" << value << '\'';
1128  throw dmlc::ParamError(os.str());
1129  }
1130  CHECK_LE(pos, value.length()); // just in case
1131  if (pos < value.length()) {
1132  std::ostringstream os;
1133  os << "Some trailing characters could not be parsed: \'"
1134  << value.substr(pos) << "\'";
1135  throw dmlc::ParamError(os.str());
1136  }
1137  }
1138 
1139  protected:
1140  // print the value
1141  virtual void PrintValue(std::ostream &os, double value) const { // NOLINT(*)
1142  os << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
1143  }
1144 };
1145 #endif // DMLC_USE_CXX11
1146 
1147 } // namespace parameter
1149 
1150 // implement GetEnv
1151 template<typename ValueType>
1152 inline ValueType GetEnv(const char *key,
1153  ValueType default_value) {
1154  const char *val = getenv(key);
1155  // On some implementations, if the var is set to a blank string (i.e. "FOO="), then
1156  // a blank string will be returned instead of NULL. In order to be consistent, if
1157  // the environment var is a blank string, then also behave as if a null was returned.
1158  if (val == nullptr || !*val) {
1159  return default_value;
1160  }
1161  ValueType ret;
1162  parameter::FieldEntry<ValueType> e;
1163  e.Init(key, &ret, ret);
1164  e.Set(&ret, val);
1165  return ret;
1166 }
1167 
1168 // implement SetEnv
1169 template<typename ValueType>
1170 inline void SetEnv(const char *key,
1171  ValueType value) {
1172  parameter::FieldEntry<ValueType> e;
1173  e.Init(key, &value, value);
1174 #ifdef _WIN32
1175  _putenv_s(key, e.GetStringValue(&value).c_str());
1176 #else
1177  setenv(key, e.GetStringValue(&value).c_str(), 1);
1178 #endif // _WIN32
1179 }
1180 } // namespace dmlc
1181 #endif // DMLC_PARAMETER_H_
Container to hold optional data.
Definition: optional.h:251
double stod(const std::string &value, size_t *pos=nullptr)
A faster implementation of stod(). See documentation of std::stod() for more information. This function will test for overflow and invalid arguments. TODO: the current version does not support hex number TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:497
A faster implementation of strtof and strtod.
#define DMLC_SUPPRESS_UBSAN
helper macro to suppress Undefined Behavior Sanitizer for a specific function
Definition: base.h:157
Lightweight JSON Reader/Writer that reads/saves C++ data structs. This includes STL composites and...
Lightweight JSON Reader to read any STL compositions and structs. The user needs to know the schema of...
Definition: json.h:44
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace letter...
Definition: strtonum.h:26
namespace for dmlc
Definition: array_view.h:12
void Write(const ValueType &value)
Write value to json.
float stof(const std::string &value, size_t *pos=nullptr)
A faster implementation of stof(). See documentation of std::stof() for more information. This function will test for overflow and invalid arguments. TODO: the current version does not support hex number TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:467
void Read(ValueType *out_value)
Read next ValueType.
type traits information header
Lightweight json to write any STL compositions.
Definition: json.h:190