mxnet
Main Page
Namespaces
Classes
Files
File List
File Members
cpp-package
include
mxnet-cpp
model.h
Go to the documentation of this file.
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#ifndef MXNET_CPP_MODEL_H_
27
#define MXNET_CPP_MODEL_H_
28
29
#include <string>
30
#include <vector>
31
#include "
mxnet-cpp/base.h
"
32
#include "
mxnet-cpp/symbol.h
"
33
#include "
mxnet-cpp/ndarray.h
"
34
35
namespace
mxnet
{
36
namespace
cpp {
37
38
struct
FeedForwardConfig
{
39
Symbol
symbol
;
40
std::vector<Context>
ctx
= {
Context::cpu
()};
41
int
num_epoch
= 0;
42
int
epoch_size
= 0;
43
std::string
optimizer
=
"sgd"
;
44
// TODO(zhangchen-qinyinghua) More implement
45
// initializer=Uniform(0.01),
46
// numpy_batch_size=128,
47
// arg_params=None, aux_params=None,
48
// allow_extra_params=False,
49
// begin_epoch=0,
50
// **kwargs):
51
FeedForwardConfig
(
const
FeedForwardConfig
&other) {}
52
FeedForwardConfig
() {}
53
};
54
class
FeedForward
{
55
public
:
56
explicit
FeedForward
(
const
FeedForwardConfig
&conf) : conf_(conf) {}
57
void
Predict();
58
void
Score();
59
void
Fit();
60
void
Save();
61
void
Load();
62
static
FeedForward
Create();
63
64
private
:
65
void
InitParams();
66
void
InitPredictor();
67
void
InitIter();
68
void
InitEvalIter();
69
FeedForwardConfig
conf_;
70
};
71
72
}
// namespace cpp
73
}
// namespace mxnet
74
75
#endif // MXNET_CPP_MODEL_H_
76
symbol.h
definition of symbol
mxnet
namespace of mxnet
Definition:
base.h:126
mxnet::cpp::Context::cpu
static Context cpu(int device_id=0)
Return a CPU context.
Definition:
ndarray.h:80
mxnet::cpp::FeedForwardConfig::FeedForwardConfig
FeedForwardConfig(const FeedForwardConfig &other)
Definition:
model.h:51
mxnet::cpp::FeedForward
Definition:
model.h:54
ndarray.h
base.h
mxnet::cpp::FeedForward::FeedForward
FeedForward(const FeedForwardConfig &conf)
Definition:
model.h:56
mxnet::cpp::FeedForwardConfig::num_epoch
int num_epoch
Definition:
model.h:41
mxnet::cpp::FeedForwardConfig::ctx
std::vector< Context > ctx
Definition:
model.h:40
mxnet::cpp::FeedForwardConfig::symbol
Symbol symbol
Definition:
model.h:39
mxnet::cpp::FeedForwardConfig
Definition:
model.h:38
mxnet::cpp::FeedForwardConfig::epoch_size
int epoch_size
Definition:
model.h:42
mxnet::cpp::FeedForwardConfig::FeedForwardConfig
FeedForwardConfig()
Definition:
model.h:52
mxnet::cpp::FeedForwardConfig::optimizer
std::string optimizer
Definition:
model.h:43
mxnet::cpp::Symbol
Symbol interface.
Definition:
symbol.h:71
Generated on Thu Sep 19 2019 13:08:39 for mxnet by
1.8.11