6 #ifndef DMLC_PARAMETER_H_ 7 #define DMLC_PARAMETER_H_ 27 #include "./logging.h" 35 struct ParamError :
public dmlc::Error {
40 explicit ParamError(
const std::string &msg)
50 template<
typename ValueType>
51 inline ValueType GetEnv(
const char *key,
52 ValueType default_value);
59 template<
typename ValueType>
60 inline void SetEnv(
const char *key,
68 class FieldAccessEntry;
70 template<
typename DType>
73 template<
typename PType>
74 struct ParamManagerSingleton;
77 enum ParamInitOption {
89 struct ParamFieldInfo {
98 std::string type_info_str;
100 std::string description;
127 template<
typename PType>
140 template<
typename Container>
141 inline void Init(
const Container &kwargs,
142 parameter::ParamInitOption option = parameter::kAllowHidden) {
143 PType::__MANAGER__()->RunInit(static_cast<PType*>(
this),
144 kwargs.begin(), kwargs.end(),
157 template<
typename Container>
158 inline std::vector<std::pair<std::string, std::string> >
159 InitAllowUnknown(
const Container &kwargs) {
160 std::vector<std::pair<std::string, std::string> > unknown;
161 PType::__MANAGER__()->RunInit(static_cast<PType*>(
this),
162 kwargs.begin(), kwargs.end(),
163 &unknown, parameter::kAllowUnknown);
179 template <
typename Container>
180 std::vector<std::pair<std::string, std::string> >
181 UpdateAllowUnknown(Container
const& kwargs,
bool* out_changed =
nullptr) {
182 std::vector<std::pair<std::string, std::string> > unknown;
183 bool changed {
false};
184 changed = PType::__MANAGER__()->RunUpdate(static_cast<PType*>(
this),
185 kwargs.begin(), kwargs.end(),
186 parameter::kAllowUnknown, &unknown,
nullptr);
187 if (out_changed) { *out_changed = changed; }
197 template<
typename Container>
198 inline void UpdateDict(Container *dict)
const {
199 PType::__MANAGER__()->UpdateDict(this->head(), dict);
205 inline std::map<std::string, std::string> __DICT__()
const {
206 std::vector<std::pair<std::string, std::string> > vec
207 = PType::__MANAGER__()->GetDict(this->head());
208 return std::map<std::string, std::string>(vec.begin(), vec.end());
215 writer->
Write(this->__DICT__());
223 std::map<std::string, std::string> kwargs;
224 reader->
Read(&kwargs);
231 inline static std::vector<ParamFieldInfo> __FIELDS__() {
232 return PType::__MANAGER__()->GetFieldInfo();
238 inline static std::string __DOC__() {
239 std::ostringstream os;
240 PType::__MANAGER__()->PrintDocString(os);
251 template<
typename DType>
252 inline parameter::FieldEntry<DType>& DECLARE(
253 parameter::ParamManagerSingleton<PType> *manager,
254 const std::string &key, DType &ref) {
255 parameter::FieldEntry<DType> *e =
256 new parameter::FieldEntry<DType>();
257 e->Init(key, this->head(), ref);
258 manager->manager.AddEntry(key, e);
264 inline PType *head()
const {
265 return static_cast<PType*
>(
const_cast<Parameter<PType>*
>(
this));
289 #define DMLC_DECLARE_PARAMETER(PType) \ 290 static ::dmlc::parameter::ParamManager *__MANAGER__(); \ 291 inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \ 297 #define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName) 304 #define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName) 314 #define DMLC_REGISTER_PARAMETER(PType) \ 315 ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ 316 static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \ 317 return &inst.manager; \ 319 static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \ 320 __make__ ## PType ## ParamManager__ = \ 321 (*PType::__MANAGER__()) \ 328 namespace parameter {
335 class FieldAccessEntry {
338 : has_default_(false), index_(0) {}
340 virtual ~FieldAccessEntry() {}
346 virtual void SetDefault(
void *head)
const = 0;
352 virtual void Set(
void *head,
const std::string &value)
const = 0;
358 virtual bool Same(
void* head,
const std::string& value)
const = 0;
360 virtual void Check(
void *head)
const {}
365 virtual std::string GetStringValue(
void *head)
const = 0;
370 virtual ParamFieldInfo GetFieldInfo()
const = 0;
382 std::string description_;
386 char* GetRawPtr(
void* head)
const {
387 return reinterpret_cast<char*
>(head) + offset_;
393 virtual void PrintDefaultValueString(std::ostream &os)
const = 0;
395 friend class ParamManager;
406 for (
size_t i = 0; i < entry_.size(); ++i) {
415 inline FieldAccessEntry *Find(
const std::string &key)
const {
416 std::map<std::string, FieldAccessEntry*>::const_iterator it =
417 entry_map_.find(key);
418 if (it == entry_map_.end())
return NULL;
431 template<
typename RandomAccessIterator>
432 inline void RunInit(
void *head,
433 RandomAccessIterator begin,
434 RandomAccessIterator end,
435 std::vector<std::pair<std::string, std::string> > *unknown_args,
436 parameter::ParamInitOption option)
const {
437 std::set<FieldAccessEntry*> selected_args;
438 RunUpdate(head, begin, end, option, unknown_args, &selected_args);
439 for (
auto const& kv : entry_map_) {
440 if (selected_args.find(kv.second) == selected_args.cend()) {
441 kv.second->SetDefault(head);
444 for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
445 it != entry_map_.end(); ++it) {
446 if (selected_args.count(it->second) == 0) {
447 it->second->SetDefault(head);
463 template <
typename RandomAccessIterator>
464 bool RunUpdate(
void *head,
465 RandomAccessIterator begin,
466 RandomAccessIterator end,
467 parameter::ParamInitOption option,
468 std::vector<std::pair<std::string, std::string> > *unknown_args,
469 std::set<FieldAccessEntry*>* selected_args =
nullptr)
const {
470 bool changed {
false};
471 for (RandomAccessIterator it = begin; it != end; ++it) {
472 if (FieldAccessEntry *e = Find(it->first)) {
473 if (!e->Same(head, it->second)) {
476 e->Set(head, it->second);
479 selected_args->insert(e);
482 if (unknown_args != NULL) {
483 unknown_args->push_back(*it);
485 if (option != parameter::kAllowUnknown) {
486 if (option == parameter::kAllowHidden &&
487 it->first.length() > 4 &&
488 it->first.find(
"__") == 0 &&
489 it->first.rfind(
"__") == it->first.length()-2) {
492 std::ostringstream os;
493 os <<
"Cannot find argument \'" << it->first <<
"\', Possible Arguments:\n";
494 os <<
"----------------\n";
496 throw dmlc::ParamError(os.str());
509 inline void AddEntry(
const std::string &key, FieldAccessEntry *e) {
510 e->index_ = entry_.size();
512 if (entry_map_.count(key) != 0) {
513 LOG(FATAL) <<
"key " << key <<
" has already been registered in " << name_;
524 inline void AddAlias(
const std::string& field,
const std::string& alias) {
525 if (entry_map_.count(field) == 0) {
526 LOG(FATAL) <<
"key " << field <<
" has not been registered in " << name_;
528 if (entry_map_.count(alias) != 0) {
529 LOG(FATAL) <<
"Alias " << alias <<
" has already been registered in " << name_;
531 entry_map_[alias] = entry_map_[field];
537 inline void set_name(
const std::string &name) {
544 inline std::vector<ParamFieldInfo> GetFieldInfo()
const {
545 std::vector<ParamFieldInfo> ret(entry_.size());
546 for (
size_t i = 0; i < entry_.size(); ++i) {
547 ret[i] = entry_[i]->GetFieldInfo();
555 inline void PrintDocString(std::ostream &os)
const {
556 for (
size_t i = 0; i < entry_.size(); ++i) {
557 ParamFieldInfo info = entry_[i]->GetFieldInfo();
558 os << info.name <<
" : " << info.type_info_str <<
'\n';
559 if (info.description.length() != 0) {
560 os <<
" " << info.description <<
'\n';
570 inline std::vector<std::pair<std::string, std::string> > GetDict(
void * head)
const {
571 std::vector<std::pair<std::string, std::string> > ret;
572 for (std::map<std::string, FieldAccessEntry*>::const_iterator
573 it = entry_map_.begin(); it != entry_map_.end(); ++it) {
574 ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
584 template<
typename Container>
585 inline void UpdateDict(
void * head, Container* dict)
const {
586 for (std::map<std::string, FieldAccessEntry*>::const_iterator
587 it = entry_map_.begin(); it != entry_map_.end(); ++it) {
588 (*dict)[it->first] = it->second->GetStringValue(head);
596 std::vector<FieldAccessEntry*> entry_;
598 std::map<std::string, FieldAccessEntry*> entry_map_;
605 template<
typename PType>
606 struct ParamManagerSingleton {
607 ParamManager manager;
608 explicit ParamManagerSingleton(
const std::string ¶m_name) {
610 manager.set_name(param_name);
611 param.__DECLARE__(
this);
617 template<
typename TEntry,
typename DType>
618 class FieldEntryBase :
public FieldAccessEntry {
621 typedef TEntry EntryType;
623 void Set(
void *head,
const std::string &value)
const override {
624 std::istringstream is(value);
625 is >> this->Get(head);
633 is.setstate(std::ios::failbit);
break;
639 std::ostringstream os;
640 os <<
"Invalid Parameter format for " << key_
641 <<
" expect " << type_ <<
" but value=\'" << value<<
'\'';
642 throw dmlc::ParamError(os.str());
649 bool Same(
void* head, std::string
const& value)
const override {
650 DType old = this->Get(head);
652 std::istringstream is(value);
655 bool is_same = std::equal(
656 reinterpret_cast<char*>(&now), reinterpret_cast<char*>(&now) +
sizeof(now),
657 reinterpret_cast<char*>(&old));
660 std::string GetStringValue(
void *head)
const override {
661 std::ostringstream os;
662 PrintValue(os, this->Get(head));
665 ParamFieldInfo GetFieldInfo()
const override {
667 std::ostringstream os;
672 os <<
',' <<
" optional, default=";
673 PrintDefaultValueString(os);
677 info.type_info_str = os.str();
678 info.description = description_;
682 void SetDefault(
void *head)
const override {
684 std::ostringstream os;
685 os <<
"Required parameter " << key_
686 <<
" of " << type_ <<
" is not presented";
687 throw dmlc::ParamError(os.str());
689 this->Get(head) = default_value_;
693 inline TEntry &
self() {
694 return *(
static_cast<TEntry*
>(
this));
697 inline TEntry &set_default(
const DType &default_value) {
698 default_value_ = default_value;
704 inline TEntry &describe(
const std::string &description) {
705 description_ = description;
710 inline void Init(
const std::string &key,
711 void *head, DType &ref) {
713 if (this->type_.length() == 0) {
714 this->type_ = dmlc::type_name<DType>();
716 this->offset_ = ((
char*)&ref) - ((
char*)head);
721 virtual void PrintValue(std::ostream &os, DType value)
const {
724 void PrintDefaultValueString(std::ostream &os)
const override {
725 PrintValue(os, default_value_);
730 inline DType &Get(
void *head)
const {
731 return *(DType*)this->GetRawPtr(head);
734 DType default_value_;
738 template<
typename TEntry,
typename DType>
739 class FieldEntryNumeric
740 :
public FieldEntryBase<TEntry, DType> {
743 : has_begin_(false), has_end_(false) {}
745 virtual TEntry &set_range(DType begin, DType end) {
746 begin_ = begin; end_ = end;
747 has_begin_ =
true; has_end_ =
true;
751 virtual TEntry &set_lower_bound(DType begin) {
752 begin_ = begin; has_begin_ =
true;
756 virtual void Check(
void *head)
const {
757 FieldEntryBase<TEntry, DType>::Check(head);
758 DType v = this->Get(head);
759 if (has_begin_ && has_end_) {
760 if (v < begin_ || v > end_) {
761 std::ostringstream os;
762 os <<
"value " << v <<
" for Parameter " << this->key_
763 <<
" exceed bound [" << begin_ <<
',' << end_ <<
']' <<
'\n';
764 os << this->key_ <<
": " << this->description_;
765 throw dmlc::ParamError(os.str());
767 }
else if (has_begin_ && v < begin_) {
768 std::ostringstream os;
769 os <<
"value " << v <<
" for Parameter " << this->key_
770 <<
" should be greater equal to " << begin_ <<
'\n';
771 os << this->key_ <<
": " << this->description_;
772 throw dmlc::ParamError(os.str());
773 }
else if (has_end_ && v > end_) {
774 std::ostringstream os;
775 os <<
"value " << v <<
" for Parameter " << this->key_
776 <<
" should be smaller equal to " << end_ <<
'\n';
777 os << this->key_ <<
": " << this->description_;
778 throw dmlc::ParamError(os.str());
784 bool has_begin_, has_end_;
794 template<
typename DType>
796 public IfThenElseType<dmlc::is_arithmetic<DType>::value,
797 FieldEntryNumeric<FieldEntry<DType>, DType>,
798 FieldEntryBase<FieldEntry<DType>, DType> >::Type {
803 class FieldEntry<int>
804 :
public FieldEntryNumeric<FieldEntry<int>, int> {
807 FieldEntry<int>() : is_enum_(
false) {}
809 typedef FieldEntryNumeric<FieldEntry<int>,
int> Parent;
811 virtual void Set(
void *head,
const std::string &value)
const {
813 std::map<std::string, int>::const_iterator it = enum_map_.find(value);
814 std::ostringstream os;
815 if (it == enum_map_.end()) {
816 os <<
"Invalid Input: \'" << value;
817 os <<
"\', valid values are: ";
819 throw dmlc::ParamError(os.str());
822 Parent::Set(head, os.str());
825 Parent::Set(head, value);
828 virtual ParamFieldInfo GetFieldInfo()
const {
831 std::ostringstream os;
836 os <<
',' <<
"optional, default=";
837 PrintDefaultValueString(os);
841 info.type_info_str = os.str();
842 info.description = description_;
845 return Parent::GetFieldInfo();
849 inline FieldEntry<int> &add_enum(
const std::string &key,
int value) {
850 if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
851 enum_back_map_.count(value) != 0) {
852 std::ostringstream os;
853 os <<
"Enum " <<
"(" << key <<
": " << value <<
" exisit!" <<
")\n";
855 for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
856 it != enum_map_.end(); ++it) {
857 os <<
"(" << it->first <<
": " << it->second <<
"), ";
859 throw dmlc::ParamError(os.str());
861 enum_map_[key] = value;
862 enum_back_map_[value] = key;
871 std::map<std::string, int> enum_map_;
873 std::map<int, std::string> enum_back_map_;
875 virtual void PrintDefaultValueString(std::ostream &os)
const {
877 PrintValue(os, default_value_);
881 virtual void PrintValue(std::ostream &os,
int value)
const {
883 CHECK_NE(enum_back_map_.count(value), 0U)
884 <<
"Value not found in enum declared";
885 os << enum_back_map_.at(value);
893 inline void PrintEnums(std::ostream &os)
const {
895 for (std::map<std::string, int>::const_iterator
896 it = enum_map_.begin(); it != enum_map_.end(); ++it) {
897 if (it != enum_map_.begin()) {
900 os <<
"\'" << it->first <<
'\'';
909 class FieldEntry<optional<int> >
910 :
public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
913 FieldEntry<optional<int> >() : is_enum_(
false) {}
915 typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
917 virtual void Set(
void *head,
const std::string &value)
const {
918 if (is_enum_ && value !=
"None") {
919 std::map<std::string, int>::const_iterator it = enum_map_.find(value);
920 std::ostringstream os;
921 if (it == enum_map_.end()) {
922 os <<
"Invalid Input: \'" << value;
923 os <<
"\', valid values are: ";
925 throw dmlc::ParamError(os.str());
928 Parent::Set(head, os.str());
931 Parent::Set(head, value);
934 virtual ParamFieldInfo GetFieldInfo()
const {
937 std::ostringstream os;
942 os <<
',' <<
"optional, default=";
943 PrintDefaultValueString(os);
947 info.type_info_str = os.str();
948 info.description = description_;
951 return Parent::GetFieldInfo();
955 inline FieldEntry<optional<int> > &add_enum(
const std::string &key,
int value) {
956 CHECK_NE(key,
"None") <<
"None is reserved for empty optional<int>";
957 if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
958 enum_back_map_.count(value) != 0) {
959 std::ostringstream os;
960 os <<
"Enum " <<
"(" << key <<
": " << value <<
" exisit!" <<
")\n";
962 for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
963 it != enum_map_.end(); ++it) {
964 os <<
"(" << it->first <<
": " << it->second <<
"), ";
966 throw dmlc::ParamError(os.str());
968 enum_map_[key] = value;
969 enum_back_map_[value] = key;
978 std::map<std::string, int> enum_map_;
980 std::map<int, std::string> enum_back_map_;
982 virtual void PrintDefaultValueString(std::ostream &os)
const {
984 PrintValue(os, default_value_);
988 virtual void PrintValue(std::ostream &os, optional<int> value)
const {
993 CHECK_NE(enum_back_map_.count(value.value()), 0U)
994 <<
"Value not found in enum declared";
995 os << enum_back_map_.at(value.value());
1004 inline void PrintEnums(std::ostream &os)
const {
1006 for (std::map<std::string, int>::const_iterator
1007 it = enum_map_.begin(); it != enum_map_.end(); ++it) {
1009 os <<
"\'" << it->first <<
'\'';
1017 class FieldEntry<
std::string>
1018 :
public FieldEntryBase<FieldEntry<std::string>, std::string> {
1021 typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
1023 virtual void Set(
void *head,
const std::string &value)
const {
1024 this->Get(head) = value;
1027 virtual void PrintDefaultValueString(std::ostream &os)
const {
1028 os <<
'\'' << default_value_ <<
'\'';
1034 class FieldEntry<bool>
1035 :
public FieldEntryBase<FieldEntry<bool>, bool> {
1038 typedef FieldEntryBase<FieldEntry<bool>,
bool> Parent;
1040 virtual void Set(
void *head,
const std::string &value)
const {
1041 std::string lower_case; lower_case.resize(value.length());
1042 std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower);
1043 bool &ref = this->Get(head);
1044 if (lower_case ==
"true") {
1046 }
else if (lower_case ==
"false") {
1048 }
else if (lower_case ==
"1") {
1050 }
else if (lower_case ==
"0") {
1053 std::ostringstream os;
1054 os <<
"Invalid Parameter format for " << key_
1055 <<
" expect " << type_ <<
" but value=\'" << value<<
'\'';
1056 throw dmlc::ParamError(os.str());
1062 virtual void PrintValue(std::ostream &os,
bool value)
const {
1063 os << static_cast<int>(value);
1072 class FieldEntry<float> :
public FieldEntryNumeric<FieldEntry<float>, float> {
1075 typedef FieldEntryNumeric<FieldEntry<float>,
float> Parent;
1077 virtual void Set(
void *head,
const std::string &value)
const {
1081 }
catch (
const std::invalid_argument &) {
1082 std::ostringstream os;
1083 os <<
"Invalid Parameter format for " << key_ <<
" expect " << type_
1084 <<
" but value=\'" << value <<
'\'';
1085 throw dmlc::ParamError(os.str());
1086 }
catch (
const std::out_of_range&) {
1087 std::ostringstream os;
1088 os <<
"Out of range value for " << key_ <<
", value=\'" << value <<
'\'';
1089 throw dmlc::ParamError(os.str());
1091 CHECK_LE(pos, value.length());
1092 if (pos < value.length()) {
1093 std::ostringstream os;
1094 os <<
"Some trailing characters could not be parsed: \'" 1095 << value.substr(pos) <<
"\'";
1096 throw dmlc::ParamError(os.str());
1102 virtual void PrintValue(std::ostream &os,
float value)
const {
1103 os << std::setprecision(std::numeric_limits<float>::max_digits10) << value;
1110 class FieldEntry<double>
1111 :
public FieldEntryNumeric<FieldEntry<double>, double> {
1114 typedef FieldEntryNumeric<FieldEntry<double>,
double> Parent;
1116 virtual void Set(
void *head,
const std::string &value)
const {
1120 }
catch (
const std::invalid_argument &) {
1121 std::ostringstream os;
1122 os <<
"Invalid Parameter format for " << key_ <<
" expect " << type_
1123 <<
" but value=\'" << value <<
'\'';
1124 throw dmlc::ParamError(os.str());
1125 }
catch (
const std::out_of_range&) {
1126 std::ostringstream os;
1127 os <<
"Out of range value for " << key_ <<
", value=\'" << value <<
'\'';
1128 throw dmlc::ParamError(os.str());
1130 CHECK_LE(pos, value.length());
1131 if (pos < value.length()) {
1132 std::ostringstream os;
1133 os <<
"Some trailing characters could not be parsed: \'" 1134 << value.substr(pos) <<
"\'";
1135 throw dmlc::ParamError(os.str());
1141 virtual void PrintValue(std::ostream &os,
double value)
const {
1142 os << std::setprecision(std::numeric_limits<double>::max_digits10) << value;
1145 #endif // DMLC_USE_CXX11 1151 template<
typename ValueType>
1152 inline ValueType GetEnv(
const char *key,
1153 ValueType default_value) {
1154 const char *val = getenv(key);
1158 if (val ==
nullptr || !*val) {
1159 return default_value;
1162 parameter::FieldEntry<ValueType> e;
1163 e.Init(key, &ret, ret);
1169 template<
typename ValueType>
1170 inline void SetEnv(
const char *key,
1172 parameter::FieldEntry<ValueType> e;
1173 e.Init(key, &value, value);
1175 _putenv_s(key, e.GetStringValue(&value).c_str());
1177 setenv(key, e.GetStringValue(&value).c_str(), 1);
1181 #endif // DMLC_PARAMETER_H_ Container to hold optional data.
Definition: optional.h:251
double stod(const std::string &value, size_t *pos=nullptr)
A faster implementation of stod(). See documentation of std::stod() for more information. This function will test for overflow and invalid arguments. TODO: the current version does not support hex number TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:497
A faster implementation of strtof and strtod.
#define DMLC_SUPPRESS_UBSAN
helper macro to supress Undefined Behavior Sanitizer for a specific function
Definition: base.h:157
Lightweight JSON Reader/Writer that read save into C++ data structs. This includes STL composites and...
Lightweight JSON Reader to read any STL compositions and structs. The user need to know the schema of...
Definition: json.h:44
bool isspace(char c)
Inline implementation of isspace(). Tests whether the given character is a whitespace letter...
Definition: strtonum.h:26
namespace for dmlc
Definition: array_view.h:12
void Write(const ValueType &value)
Write value to json.
float stof(const std::string &value, size_t *pos=nullptr)
A faster implementation of stof(). See documentation of std::stof() for more information. This function will test for overflow and invalid arguments. TODO: the current version does not support hex number TODO: the current version does not handle long decimals: you may only have up to 19 digits after the decimal point, and you cannot have too many digits before the decimal point either.
Definition: strtonum.h:467
void Read(ValueType *out_value)
Read next ValueType.
type traits information header
Lightweight json to write any STL compositions.
Definition: json.h:190