25 #ifndef MSHADOW_EXPR_ENGINE_INL_H_ 26 #define MSHADOW_EXPR_ENGINE_INL_H_ 29 #include "./logging.h" 42 template<
typename SubType,
typename SrcExp,
int dim,
typename DType>
44 :
public Exp<MakeTensorExp<SubType, SrcExp, dim, DType>,
45 DType, type::kChainer> {
50 return *
static_cast<const SubType*
>(
this);
57 template<
typename ExpType,
typename DType>
67 template <
typename Device,
int dim,
typename DType>
71 : dptr_(t.dptr_), stride_(t.stride_) {}
74 return dptr_[y * stride_ + x];
78 return dptr_[y * stride_ + x];
86 template <
typename Device,
typename DType>
101 template<
typename DType>
113 template<
typename DstDType,
typename SrcDType,
114 typename EType,
int etype>
119 return DstDType(src_.Eval(y, x));
127 template<
typename OP,
typename TA,
typename TB,
typename TC,
int etype,
typename DType>
132 : item1_(item1), item2_(item2), item3_(item3) {}
134 return OP::Map(item1_.Eval(y, x), item2_.Eval(y, x), item3_.Eval(y, x));
143 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
147 : lhs_(lhs), rhs_(rhs) {}
149 return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x));
157 template<
typename OP,
typename TA,
int etype,
typename DType>
162 return OP::Map(src_.Eval(y, x));
169 template<
typename SubType,
typename SrcExp,
int dim,
typename DType>
174 return src_.Eval(y, x);
181 template<
typename EType,
typename DType>
186 return src_.Eval(x, y);
195 template<
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
199 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int etype>
203 template<
typename DType>
208 template<
typename DstDType,
typename SrcDType,
typename EType,
int etype>
211 return Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType>(
MakePlan(e.
exp));
214 template<
typename T,
typename DType>
219 template<
typename T,
typename DType>
222 return Plan<TransposeExp<T, DType>, DType>(
MakePlan(e.
exp));
225 template<
typename T,
typename SrcExp,
int dim,
typename DType>
231 template<
typename OP,
typename TA,
typename DType,
int etype>
234 return Plan<UnaryMapExp<OP, TA, DType, etype>, DType>(
MakePlan(e.
src_));
237 template<
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
238 inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType>
240 return Plan<BinaryMapExp<OP, TA, TB, DType, etype>,
245 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int etype>
246 inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType>
248 return Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
263 static const int kDim = -1;
264 static const int kDevMask = 0;
266 template<
typename DType>
268 static const int kDim = 0;
269 static const int kDevMask = 0xffff;
271 template<
typename E,
typename DType>
276 template<
typename DstDType,
typename SrcDType,
typename EType,
int etype>
281 template<
typename Device,
int dim,
typename DType>
283 static const int kDim = dim;
284 static const int kDevMask = Device::kDevMask;
286 template<
typename T,
typename SrcExp,
int dim,
typename DType>
289 static const int kDim = kDimSrc >= 0 ? dim : -1;
292 template<
typename OP,
typename TA,
typename DType,
int etype>
297 template<
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
301 static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\
304 ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
307 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int etype>
312 static const int kDim = kDimItem1;
317 template<
typename Device,
int dim,
typename DType,
typename E>
324 static const bool kMapPass = (kExpDim == 0 || kExpDim == dim) && kDevPass;
326 static const bool kRedPass = (kExpDim > dim) && kDevPass;
344 template<
typename Device,
typename E>
348 template<
int dim,
typename Device,
typename DType>
363 template<
int dim,
typename E>
367 template<
int dim,
typename DType>
372 for (
int i = 0; i < dim; ++i) {
378 template<
int dim,
typename DstDType,
typename SrcDType,
typename EType,
int etype>
385 template<
int dim,
typename E,
typename DType>
390 std::swap(s[0], s[1]);
394 template<
int dim,
typename Device,
typename DType>
400 template<
int dim,
typename SrcExp,
typename T,
typename DType>
407 template<
int dim,
typename OP,
typename TA,
typename DType,
int etype>
415 template<
int dim,
typename OP,
typename TA,
typename TB,
416 typename DType,
int etype>
422 if (shape1[0] == 0)
return shape2;
423 if (shape2[0] == 0)
return shape1;
424 CHECK_EQ(shape1, shape2) <<
"BinaryMapExp: Shapes of operands are not the same, " <<
425 "Shape1=" << shape1 <<
", Shape2=" << shape2;
430 template<
int dim,
typename OP,
typename TA,
typename TB,
typename TC,
431 typename DType,
int etype>
438 bool same = (shape1 == shape2) && (shape2 == shape3);
439 CHECK(same) <<
"TernaryMapExp: Shapes of operands are not the same, " <<
440 "Shape1=" << shape1 <<
", Shape2=" << shape2 <<
", Shape3=" << shape3;
454 template<
typename SV,
typename RV,
typename E,
typename DType>
456 inline static void Eval(RV *dst,
const E &exp);
459 template<
typename SV,
typename RV,
typename DType>
462 inline static void Eval(RV *dst,
464 MapExp<SV>(dst, exp);
467 inline static void Eval(RV *dst,
469 MapExp<SV>(dst, exp);
472 inline static void Eval(RV *dst,
474 MapExp<SV>(dst, exp);
477 inline static void Eval(RV *dst,
482 template<
typename SV,
typename Device,
int dim,
int ldim,
483 int rdim,
bool ltrans,
bool rtrans,
typename DType>
485 Tensor<Device, dim, DType>,
487 Tensor<Device, rdim, DType>,
488 ltrans, rtrans, DType>,
493 ltrans, rtrans, DType> &exp) {
495 ltrans, rtrans, DType>::Eval(dst, exp.lhs_, exp.rhs_, exp.scale_);
500 #endif // MSHADOW_EXPR_ENGINE_INL_H_ Plan(const Plan< TA, DType > &src)
Definition: expr_engine-inl.h:160
MSHADOW_XINLINE const DType & Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:93
static Shape< dim > Check(const UnaryMapExp< OP, TA, DType, etype > &t)
Definition: expr_engine-inl.h:409
static void Eval(RV *dst, const Exp< E, DType, type::kRValue > &exp)
Definition: expr_engine-inl.h:472
static Shape< dim > Check(const MakeTensorExp< T, SrcExp, dim, DType > &t)
Definition: expr_engine-inl.h:403
static Shape< dim > Check(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:396
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:103
static Shape< dim > Check(const BinaryMapExp< OP, TA, TB, DType, etype > &t)
Definition: expr_engine-inl.h:419
const Container & self(void) const
Definition: expression.h:82
Definition: expr_engine-inl.h:58
Plan(const Tensor< Device, 1, DType > &t)
Definition: expr_engine-inl.h:89
used to help static type check
Definition: expr_engine-inl.h:330
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:133
template to do type check
Definition: expr_engine-inl.h:318
const TB & rhs_
right operand
Definition: expression.h:339
Plan(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:70
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:436
ternary map expression
Definition: expression.h:279
static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void)
Definition: expr_engine-inl.h:336
binary map expression lhs [op] rhs
Definition: expression.h:334
Plan(DType scalar)
Definition: expr_engine-inl.h:104
mshadow::expr::ExpComplexEngine< SV, Tensor< Device, dim, DType >, DotExp< Tensor< Device, ldim, DType >, Tensor< Device, rdim, DType >, ltrans, rtrans, DType >, DType >::Eval static void Eval(Tensor< Device, dim, DType > *dst, const DotExp< Tensor< Device, ldim, DType >, Tensor< Device, rdim, DType >, ltrans, rtrans, DType > &exp)
Definition: expr_engine-inl.h:490
static void Eval(RV *dst, const E &exp)
base class of all rvalues
Definition: expression.h:148
Definition: dot_engine-inl.h:70
DType scalar_
scalar value
Definition: expression.h:97
MSHADOW_XINLINE DType & REval(index_t y, index_t x)
Definition: expr_engine-inl.h:73
const EType & exp
expression to be transposed
Definition: expression.h:134
static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void)
Definition: expr_engine-inl.h:337
static Shape< dim > Check(const E &t)
header file of tensor data structure and functions. This lib requires explicit memory allocation and d...
static void Eval(RV *dst, const Exp< E, DType, type::kChainer > &exp)
Definition: expr_engine-inl.h:467
#define MSHADOW_XINLINE
Definition: base.h:230
const TB & item2_
second operand
Definition: expression.h:284
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:262
Definition: expr_engine-inl.h:345
definitions of abstract expressions and expressions template
static void Eval(RV *dst, const Exp< E, DType, type::kComplex > &exp)
Definition: expr_engine-inl.h:477
static Shape< dim > Check(const TernaryMapExp< OP, TA, TB, TC, DType, etype > &t)
Definition: expr_engine-inl.h:434
static void Error_Expression_Does_Not_Meet_Dimension_Req(void)
Definition: expr_engine-inl.h:338
int32_t index_t
type that will be used for index
Definition: base.h:343
Plan(const Plan< EType, DType > &src)
Definition: expr_engine-inl.h:184
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:161
static Shape< dim > Check(const TypecastExp< DstDType, SrcDType, EType, etype > &exp)
Definition: expr_engine-inl.h:381
Plan(const Plan< TA, DType > &item1, const Plan< TB, DType > &item2, const Plan< TC, DType > &item3)
Definition: expr_engine-inl.h:130
const TA & item1_
first operand
Definition: expression.h:282
typecast expression, cast the type of elements
Definition: expression.h:114
static Shape< dim > Check(const ScalarExp< DType > &exp)
Definition: expr_engine-inl.h:369
runtime shape checking template: gets the shape of an expression and reports an error if shapes mismatch ...
Definition: expr_engine-inl.h:364
represent a transpose expression of a container
Definition: expression.h:131
an engine that evaluates complex expressions
Definition: expr_engine-inl.h:455
const TA & src_
source expression
Definition: expression.h:407
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:105
unary map expression op(src)
Definition: expression.h:404
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:224
Plan(const Plan< EType, SrcDType > &src)
Definition: expr_engine-inl.h:117
scalar expression
Definition: expression.h:95
Plan(const Plan< SubType, DType > &src)
Definition: expr_engine-inl.h:172
MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:118
const SubType & real_self(void) const
true self of subtype
Definition: expr_engine-inl.h:49
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
const EType & exp
expression to be typecasted
Definition: expression.h:118
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
const TC & item3_
third operand
Definition: expression.h:286
MSHADOW_XINLINE const DType & Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:77
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
const TA & lhs_
left operand
Definition: expression.h:337
overloaded + operator between half_t and bf16_t
Definition: base.h:334
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:148
static Shape< dim > Check(const TransposeExp< E, DType > &e)
Definition: expr_engine-inl.h:387
the engine that dispatches simple operations
Definition: expr_engine-inl.h:460
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:185
general tensor
Definition: tensor.h:420
static void Eval(RV *dst, const Exp< E, DType, type::kMapper > &exp)
Definition: expr_engine-inl.h:462
Stream< Device > * stream_
stream where the computation lies; a stream is a device dependency concept where each computation ...
Definition: tensor.h:446
definitions of how Matrix Multiplications can be evaluated
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: expr_engine-inl.h:146
MSHADOW_XINLINE DType & REval(index_t y, index_t x)
Definition: expr_engine-inl.h:90
static Stream< Device > * Get(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:350
computation stream structure, used for asynchronous computations
Definition: tensor.h:383
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:173