mxnet
expression.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXPRESSION_H_
26 #define MSHADOW_EXPRESSION_H_
27 #include "./base.h"
28 
29 namespace mshadow {
36 namespace expr {
/*! \brief integer tags describing the kind of an expression */
namespace type {
// Expression types are defined as bitmasks.
// Subtype relationship: kRValue < kMapper < kChainer < kComplex
// (the set bits of each value include those of the previous one: 0, 1, 3, 7),
// so OR-ing two tags yields the more general of the two.
/*! \brief this expression directly corresponds to a data class, can be used to assign data */
const int kRValue = 0;
/*! \brief expression contains element-wise tensor operations, mapping an expression to the same shape */
const int kMapper = 1;
/*!
 * \brief expression that can be chained with other expressions.
 *  Usually it has a function Eval(i, j) defined, which pulls the result (i, j)
 *  from the input expression and outputs the result at a certain position.
 */
const int kChainer = 3;
/*! \brief other cases, e.g. dot product */
const int kComplex = 7;
}  // namespace type
/*!
 * \brief forward declaration of the expression engine.
 *  Only the declaration lives here; RValueExp invokes it as
 *  ExpEngine<Saver, Container, DType>::Eval(dst, exp).
 * \tparam Saver the save policy (e.g. sv::saveto, sv::plusto)
 * \tparam RValue the destination type being written to
 * \tparam DType the element data type
 */
template<typename Saver,
         typename RValue,
         typename DType>
struct ExpEngine;
// Sketch of the evaluation entry point (kept commented out, as in the original):
// template<typename EType>
// inline static void Eval(RValue *dst, const EType &exp);
/*!
 * \brief CRTP base of every expression: gives access to the derived object.
 * \tparam SubType the derived expression type that inherits from this base
 * \tparam DType the element data type carried by the expression
 * \tparam exp_type integer tag describing the expression kind
 */
template<typename SubType, typename DType, int exp_type>
struct Exp {
  /*! \return const reference to this object viewed as its derived type */
  const SubType &self() const {
    const SubType *derived = static_cast<const SubType *>(this);
    return *derived;
  }
  /*! \return mutable pointer to this object viewed as its derived type */
  SubType *ptrself() {
    SubType *derived = static_cast<SubType *>(this);
    return derived;
  }
};
94 template<typename DType>
95 struct ScalarExp: public Exp<ScalarExp<DType>, DType, type::kMapper> {
97  DType scalar_;
99  ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*)
100 };
102 template<typename DType>
103 inline ScalarExp<DType> scalar(DType s) {
104  return ScalarExp<DType>(s);
105 }
113 template<typename DstDType, typename SrcDType, typename EType, int etype>
114 struct TypecastExp:
115  public Exp<TypecastExp<DstDType, SrcDType, EType, etype>,
116  DstDType, etype> {
118  const EType &exp;
120  explicit TypecastExp(const EType &e) : exp(e) {}
121 };
123 template<typename DstDType, typename SrcDType,
124  typename EType, int etype>
128 }
130 template<typename EType, typename DType>
131 struct TransposeExp: public Exp<TransposeExp<EType, DType>,
132  DType, type::kChainer> {
134  const EType &exp;
136  explicit TransposeExp(const EType &e) : exp(e) {}
138  inline const EType &T(void) const {
139  return exp;
140  }
141 };
147 template<typename Container, typename DType>
148 class RValueExp: public Exp<Container, DType, type::kRValue> {
149  public:
154  inline const TransposeExp<Container, DType> T(void) const {
155  return TransposeExp<Container, DType>(this->self());
156  }
158  inline Container &operator+=(DType s) {
159  ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
160  return *(this->ptrself());
161  }
163  inline Container &operator-=(DType s) {
164  ExpEngine<sv::minusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
165  return *(this->ptrself());
166  }
168  inline Container &operator*=(DType s) {
169  ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
170  return *(this->ptrself());
171  }
173  inline Container &operator/=(DType s) {
174  ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
175  return *(this->ptrself());
176  }
178  inline Container &__assign(DType s) {
179  ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
180  return *(this->ptrself());
181  }
183  template<typename E, int etype>
184  inline Container &__assign(const Exp<E, DType, etype> &exp) {
185  ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), exp.self());
186  return *(this->ptrself());
187  }
189  inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp);
191  template<typename E, int etype>
192  inline Container &operator+=(const Exp<E, DType, etype> &exp) {
193  ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), exp.self());
194  return *(this->ptrself());
195  }
197  template<typename E, int etype>
198  inline Container &operator-=(const Exp<E, DType, etype> &exp) {
200  return *(this->ptrself());
201  }
203  template<typename E, int etype>
204  inline Container &operator*=(const Exp<E, DType, etype> &exp) {
205  ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), exp.self());
206  return *(this->ptrself());
207  }
209  template<typename E, int etype>
210  inline Container &operator/=(const Exp<E, DType, etype> &exp) {
211  ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), exp.self());
212  return *(this->ptrself());
213  }
214 };
223 template<typename TA, typename TB, bool ltrans, bool rtrans, typename DType>
224 struct DotExp: public Exp<DotExp<TA, TB, ltrans, rtrans, DType>,
225  DType, type::kComplex> {
227  const TA &lhs_;
229  const TB &rhs_;
231  DType scale_;
233  explicit DotExp(const TA &lhs, const TB &rhs, DType scale)
234  : lhs_(lhs), rhs_(rhs), scale_(scale) {}
235 };
236 // definition of dot expression
238 template<typename TA, typename TB, typename DType>
241  return DotExp<TA, TB, false, false, DType>(lhs.self(), rhs.self(), DType(1.0f));
242 }
244 template<typename TA, typename TB, typename DType>
247  return DotExp<TA, TB, true, false, DType>(lhs.exp, rhs.self(), DType(1.0f));
248 }
250 template<typename TA, typename TB, typename DType>
253  return DotExp<TA, TB, false, true, DType>(lhs.self(), rhs.exp, DType(1.0f));
254 }
256 template<typename TA, typename TB, typename DType>
259  return DotExp<TA, TB, true, true, DType>(lhs.exp, rhs.exp, DType(1.0f));
260 }
262 template<bool transpose_left, bool transpose_right, typename TA, typename TB, typename DType>
266  lhs.self(), rhs.self(), DType(1.0f));
267 }
268 //---------------
269 // TernaryMapExp
270 // --------------
278 template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
279 struct TernaryMapExp: public Exp<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
280  DType, etype> {
282  const TA &item1_;
284  const TB &item2_;
286  const TC &item3_;
288  explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
289  :item1_(item1), item2_(item2), item3_(item3) {}
290 };
291 
293 template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
295 MakeExp(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
296  const Exp<TC, DType, tc> &item3) {
297  return TernaryMapExp<OP, TA, TB, TC, DType,
298  (ta|tb|tc|type::kMapper)>(item1.self(), item2.self(), item3.self());
299 }
316 // Ternary
317 template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
319 F(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
320  const Exp<TC, DType, tc> &item3) {
321  return MakeExp<OP>(item1, item2, item3);
322 }
323 //---------------
324 // BinaryMapExp
325 // --------------
333 template<typename OP, typename TA, typename TB, typename DType, int etype>
334 struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>,
335  DType, etype> {
337  const TA &lhs_;
339  const TB &rhs_;
341  explicit BinaryMapExp(const TA &lhs, const TB &rhs)
342  :lhs_(lhs), rhs_(rhs) {}
343 };
344 
346 template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
349  return BinaryMapExp<OP, TA, TB, DType,
350  (ta|tb|type::kMapper)>(lhs.self(), rhs.self());
351 }
364 template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
366 F(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
367  return MakeExp<OP>(lhs, rhs);
368 }
369 // operator rules
371 template<typename TA, typename TB, typename DType, int ta, int tb>
374  return MakeExp<op::plus>(lhs, rhs);
375 }
377 template<typename TA, typename TB, typename DType, int ta, int tb>
380  return MakeExp<op::minus>(lhs, rhs);
381 }
383 template<typename TA, typename TB, typename DType, int ta, int tb>
386  return MakeExp<op::mul>(lhs, rhs);
387 }
389 template<typename TA, typename TB, typename DType, int ta, int tb>
392  return MakeExp<op::div>(lhs, rhs);
393 }
394 //---------------
395 // UnaryMapExp
396 // --------------
403 template<typename OP, typename TA, typename DType, int etype>
404 struct UnaryMapExp: public Exp<UnaryMapExp<OP, TA, DType, etype>,
405  DType, etype> {
407  const TA &src_;
409  explicit UnaryMapExp(const TA &src) : src_(src) {}
410 };
411 
413 template<typename OP, typename TA, typename DType, int ta>
417 }
427 template<typename OP, typename TA, typename DType, int ta>
429 F(const Exp<TA, DType, ta> &src) {
430  return MakeExp<OP>(src);
431 }
432 } // namespace expr
433 } // namespace mshadow
434 #endif // MSHADOW_EXPRESSION_H_
const int kRValue
this expression directly corresponds to a data class, can be used to assign data
Definition: expression.h:45
const int kMapper
expression contains element-wise tensor operations, mapping an expression to the same shape ...
Definition: expression.h:50
DotExp< TA, TB, transpose_left, transpose_right, DType > batch_dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
batch_dot operator def
Definition: expression.h:264
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:103
BinaryMapExp< op::minus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator-(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:101
BinaryMapExp< op::plus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator+(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:93
const SubType & self(void) const
Definition: expression.h:82
UnaryMapExp(const TA &src)
constructor
Definition: expression.h:409
const TB & rhs_
right operand
Definition: expression.h:339
SubType * ptrself(void)
Definition: expression.h:86
Container & operator/=(const Exp< E, DType, etype > &exp)
implementation of operator/=
Definition: expression.h:210
const int kComplex
other cases, e.g. dot product
Definition: expression.h:58
ternary map expression
Definition: expression.h:279
const EType & T(void) const
transpose expression
Definition: expression.h:138
binary map expression lhs [op] rhs
Definition: expression.h:334
DotExp< TA, TB, false, false, DType > dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
dot operator def
Definition: expression.h:240
TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
constructor
Definition: expression.h:288
TypecastExp(const EType &e)
constructor
Definition: expression.h:120
BinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: expression.h:341
TypecastExp< DstDType, SrcDType, EType,(etype|type::kMapper)> tcast(const Exp< EType, SrcDType, etype > &exp)
create a typecast expression
Definition: expression.h:126
Container & operator*=(const Exp< E, DType, etype > &exp)
implementation of operator*=
Definition: expression.h:204
base class of all rvalues
Definition: expression.h:148
DType scalar_
scalar value
Definition: expression.h:97
const TB & rhs_
right operand
Definition: expression.h:229
const EType & exp
expression to be transposed
Definition: expression.h:134
const int kChainer
expression that can be chained with other expressions. Usually it has a function Eval(i,j) defined, which pulls the result (i, j) from the input expression and outputs the result at a certain position.
Definition: expression.h:56
const TB & item2_
second operand
Definition: expression.h:284
ScalarExp(DType scalar)
implicit constructor, MUST NOT BE explicit
Definition: expression.h:99
Container & operator-=(DType s)
operator overload
Definition: expression.h:163
Container & operator-=(const Exp< E, DType, etype > &exp)
implementation of operator-=
Definition: expression.h:198
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:71
DType scale_
scale over result
Definition: expression.h:231
DotExp(const TA &lhs, const TB &rhs, DType scale)
constructor
Definition: expression.h:233
const TransposeExp< Container, DType > T(void) const
transpose of a matrix
Definition: expression.h:154
Container & operator/=(DType s)
operator overload
Definition: expression.h:173
const TA & item1_
first operand
Definition: expression.h:282
typecast expression, cast the type of elements
Definition: expression.h:114
represent a transpose expression of a container
Definition: expression.h:131
Container & operator+=(DType s)
operator overload
Definition: expression.h:158
const TA & src_
source expression
Definition: expression.h:407
unary map expression op(src)
Definition: expression.h:404
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:224
scalar expression
Definition: expression.h:95
TernaryMapExp< OP, TA, TB, TC, DType,(ta|tb|tc|type::kMapper)> MakeExp(const Exp< TA, DType, ta > &item1, const Exp< TB, DType, tb > &item2, const Exp< TC, DType, tc > &item3)
make expression
Definition: expression.h:295
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
const EType & exp
expression to be typecasted
Definition: expression.h:118
const TC & item3_
third operand
Definition: expression.h:286
Container & operator*=(DType s)
operator overload
Definition: expression.h:168
BinaryMapExp< op::div, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator/(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:117
const TA & lhs_
left operand
Definition: expression.h:227
const TA & lhs_
left operand
Definition: expression.h:337
overloaded + operator between half_t and bf16_t
Definition: base.h:334
DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > operator*(const DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > &lhs, MSHADOW_SCALAR_ rhs)
dot operator def
Definition: expr_scalar-inl.h:40
the engine that dispatches simple operations
Definition: expr_engine-inl.h:460
Container & __assign(DType s)
operator overload
Definition: expression.h:178
Container & operator+=(const Exp< E, DType, etype > &exp)
implementation of operator+=
Definition: expression.h:192
static void Eval(RV *dst, const Exp< E, DType, type::kMapper > &exp)
Definition: expr_engine-inl.h:462
TransposeExp(const EType &e)
constructor
Definition: expression.h:136
Container & __assign(const Exp< E, DType, etype > &exp)
we can not define container = container
Definition: expression.h:184