mxnet
range.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_EXTENSION_RANGE_H_
27 #define MSHADOW_EXTENSION_RANGE_H_
28 
29 #include "../extension.h"
30 
31 namespace mshadow {
32 namespace expr {
42 template<typename DType>
43 struct RangeExp:
44  public Exp<RangeExp<DType>, DType, type::kMapper> {
45  const DType start_;
46  const DType stop_;
47  const DType step_;
48  const int repeat_;
50  RangeExp(DType start, DType stop, DType step, int repeat)
51  : start_(start), stop_(stop), step_(step), repeat_(repeat) {}
52 };
53 
54 template<typename DType>
55 inline RangeExp<DType>
56 range(DType start, DType stop, DType step = 1, int repeat = 1) {
57  return RangeExp<DType>(start, stop, step, repeat);
58 }
59 
60 //----------------------
61 // Execution plan
62 //----------------------
63 template<typename DType>
64 struct Plan<RangeExp<DType>, DType> {
65  public:
66  explicit Plan(const RangeExp<DType> &e)
67  : start_(e.start_),
68  stop_(e.stop_),
69  step_(e.step_),
70  repeat_(e.repeat_) {
71  }
72  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
73  return start_ + static_cast<DType>((static_cast<int>(x) / repeat_)) * step_;
74  }
75 
76  private:
77  const DType start_;
78  const DType stop_;
79  const DType step_;
80  const int repeat_;
81 };
82 
83 template<typename DType>
84 inline Plan<RangeExp<DType>, DType>
86  return Plan<RangeExp<DType>, DType>(exp);
87 }
88 
89 
/*!
 * \brief Number of elements produced by a range with the given parameters.
 *
 * Generic version, intended for integer-like DType. The count of distinct
 * values is ceil((stop - start) / step), computed with truncating division:
 * the numerator is moved one unit towards zero before dividing and one is
 * added afterwards. The direction of that adjustment depends on the sign of
 * the step, otherwise descending ranges are over-counted.
 *
 * The arguments are not validated here: the caller is expected to guarantee
 * step != 0, start < stop for a positive step and start > stop for a
 * negative step.
 *
 * \param start first value of the sequence
 * \param stop exclusive bound of the sequence
 * \param step increment between distinct values
 * \param repeat number of consecutive copies of each value
 * \return repeat * (number of distinct values)
 */
template<typename DType>
inline int RangeOutSize(DType start, DType stop, DType step, int repeat) {
  if (step > 0) {
    // ascending: stop - start > 0, shrink it by one before truncating
    return repeat * ((stop - start - 1) / step + 1);
  }
  // descending: stop - start < 0, so "towards zero" means adding one
  return repeat * ((stop - start + 1) / step + 1);
}

/*!
 * \brief float specialization: the quotient is evaluated in double precision
 *        and rounded up with ceil.
 */
template<>
inline int RangeOutSize<float>(float start, float stop, float step, int repeat) {
  double d_start = static_cast<double>(start);
  double d_stop = static_cast<double>(stop);
  double d_step = static_cast<double>(step);
  return repeat * static_cast<int>(ceil((d_stop - d_start) / d_step));
}

/*!
 * \brief double specialization: ceil((stop - start) / step) scaled by repeat.
 */
template<>
inline int RangeOutSize<double>(double start, double stop, double step, int repeat) {
  return repeat * static_cast<int>(ceil((stop - start) / step));
}
107 
108 
109 template<int dim, typename DType>
110 struct ShapeCheck<dim, RangeExp<DType> > {
111  inline static Shape<dim>
112  Check(const RangeExp<DType> &t) {
113  CHECK(dim == 1)
114  << "RangeExp only support 1 dimension output, received " << dim;
115  CHECK(t.step_ != 0)
116  << "RangeExp does not support step=0, received " << t.step_;
117  CHECK(t.repeat_ > 0)
118  << "RangeExp only supports repeat > 0, received " << t.repeat_;
119  if (t.step_ > 0) {
120  CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = "
121  << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
122  } else {
123  CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= "
124  << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
125  }
126  return Shape1(RangeOutSize<DType>(t.start_, t.stop_, t.step_, t.repeat_));
127  }
128 };
129 
130 template<typename DType>
131 struct ExpInfo<RangeExp<DType> > {
132  static const int kDim = 1;
133  static const int kDevMask = 0xffff;
134 };
135 } // namespace expr
136 } // namespace mshadow
137 #endif // MSHADOW_EXTENSION_RANGE_H_
Definition: expr_engine-inl.h:59
int RangeOutSize(DType start, DType stop, DType step, int repeat)
Definition: range.h:91
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: range.h:72
Plan(const RangeExp< DType > &e)
Definition: range.h:66
const DType start_
Definition: range.h:45
#define MSHADOW_XINLINE
Definition: base.h:223
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
const int repeat_
Definition: range.h:48
const DType stop_
Definition: range.h:46
int32_t index_t
type that will be used for index
Definition: base.h:336
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:207
runtime shape checking template get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:365
const DType step_
Definition: range.h:47
RangeExp< DType > range(DType start, DType stop, DType step=1, int repeat=1)
Definition: range.h:56
int RangeOutSize< double >(double start, double stop, double step, int repeat)
Definition: range.h:104
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
RangeExp(DType start, DType stop, DType step, int repeat)
constructor
Definition: range.h:50
Generate a range vector similar to python: range(start, stop[, step][, repeat]). If step is positive...
Definition: range.h:43
overloaded + operator between half_t and bf16_t
Definition: base.h:327
static Shape< dim > Check(const RangeExp< DType > &t)
Definition: range.h:112
int RangeOutSize< float >(float start, float stop, float step, int repeat)
Definition: range.h:96