mxnet
Main Page
Namespaces
Classes
Files
File List
File Members
cpp-package
include
mxnet-cpp
model.h
Go to the documentation of this file.
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
27
#ifndef MXNET_CPP_MODEL_H_
28
#define MXNET_CPP_MODEL_H_
29
30
#include <string>
31
#include <vector>
32
#include "
mxnet-cpp/base.h
"
33
#include "
mxnet-cpp/symbol.h
"
34
#include "
mxnet-cpp/ndarray.h
"
35
36
namespace
mxnet
{
37
namespace
cpp {
38
39
struct
FeedForwardConfig
{
40
Symbol
symbol
;
41
std::vector<Context>
ctx
= {
Context::cpu
()};
42
int
num_epoch
= 0;
43
int
epoch_size
= 0;
44
std::string
optimizer
=
"sgd"
;
45
// TODO(zhangchen-qinyinghua) More implement
46
// initializer=Uniform(0.01),
47
// numpy_batch_size=128,
48
// arg_params=None, aux_params=None,
49
// allow_extra_params=False,
50
// begin_epoch=0,
51
// **kwargs):
52
FeedForwardConfig
(
const
FeedForwardConfig
&other) {}
53
FeedForwardConfig
() {}
54
};
55
class
FeedForward
{
56
public
:
57
explicit
FeedForward
(
const
FeedForwardConfig
&conf) : conf_(conf) {}
58
void
Predict();
59
void
Score();
60
void
Fit();
61
void
Save();
62
void
Load();
63
static
FeedForward
Create();
64
65
private
:
66
void
InitParams();
67
void
InitPredictor();
68
void
InitIter();
69
void
InitEvalIter();
70
FeedForwardConfig
conf_;
71
};
72
73
}
// namespace cpp
74
}
// namespace mxnet
75
76
#endif // MXNET_CPP_MODEL_H_
77
symbol.h
definition of symbol
mxnet
namespace of mxnet
Definition:
base.h:118
mxnet::cpp::Context::cpu
static Context cpu(int device_id=0)
Return a CPU context.
Definition:
ndarray.h:81
mxnet::cpp::FeedForwardConfig::FeedForwardConfig
FeedForwardConfig(const FeedForwardConfig &other)
Definition:
model.h:52
mxnet::cpp::FeedForward
Definition:
model.h:55
ndarray.h
base.h
mxnet::cpp::FeedForward::FeedForward
FeedForward(const FeedForwardConfig &conf)
Definition:
model.h:57
mxnet::cpp::FeedForwardConfig::num_epoch
int num_epoch
Definition:
model.h:42
mxnet::cpp::FeedForwardConfig::ctx
std::vector< Context > ctx
Definition:
model.h:41
mxnet::cpp::FeedForwardConfig::symbol
Symbol symbol
Definition:
model.h:40
mxnet::cpp::FeedForwardConfig
Definition:
model.h:39
mxnet::cpp::FeedForwardConfig::epoch_size
int epoch_size
Definition:
model.h:43
mxnet::cpp::FeedForwardConfig::FeedForwardConfig
FeedForwardConfig()
Definition:
model.h:53
mxnet::cpp::FeedForwardConfig::optimizer
std::string optimizer
Definition:
model.h:44
mxnet::cpp::Symbol
Symbol interface.
Definition:
symbol.h:72
Generated on Thu Sep 19 2019 12:47:36 for mxnet by doxygen 1.8.11