mxnet
model.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MXNET_CPP_MODEL_H_
28 #define MXNET_CPP_MODEL_H_
29 
30 #include <string>
31 #include <vector>
32 #include "mxnet-cpp/base.h"
33 #include "mxnet-cpp/symbol.h"
34 #include "mxnet-cpp/ndarray.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
41  std::vector<Context> ctx = {Context::cpu()};
42  int num_epoch = 0;
43  int epoch_size = 0;
44  std::string optimizer = "sgd";
45  // TODO(zhangchen-qinyinghua) Implement the remaining options listed below:
46  // initializer=Uniform(0.01),
47  // numpy_batch_size=128,
48  // arg_params=None, aux_params=None,
49  // allow_extra_params=False,
50  // begin_epoch=0,
51  // **kwargs):
54 };
55 class FeedForward {
56  public:
57  explicit FeedForward(const FeedForwardConfig &conf) : conf_(conf) {}
58  void Predict();
59  void Score();
60  void Fit();
61  void Save();
62  void Load();
63  static FeedForward Create();
64 
65  private:
66  void InitParams();
67  void InitPredictor();
68  void InitIter();
69  void InitEvalIter();
70  FeedForwardConfig conf_;
71 };
72 
73 } // namespace cpp
74 } // namespace mxnet
75 
76 #endif // MXNET_CPP_MODEL_H_
77 
definition of symbol
namespace of mxnet
Definition: base.h:127
static Context cpu(int device_id=0)
Return a CPU context.
Definition: ndarray.h:81
FeedForwardConfig(const FeedForwardConfig &other)
Definition: model.h:52
Definition: model.h:55
FeedForward(const FeedForwardConfig &conf)
Definition: model.h:57
int num_epoch
Definition: model.h:42
std::vector< Context > ctx
Definition: model.h:41
Symbol symbol
Definition: model.h:40
Definition: model.h:39
int epoch_size
Definition: model.h:43
FeedForwardConfig()
Definition: model.h:53
std::string optimizer
Definition: model.h:44
Symbol interface.
Definition: symbol.h:72