30 #ifndef MSHADOW_EXPR_SCALAR_INL_H_ 31 #define MSHADOW_EXPR_SCALAR_INL_H_ 33 #undef MSHADOW_EXPR_SCALAR_INL_H_ 39 template<
typename TA,
typename TB,
bool ltrans,
bool rtrans>
40 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
43 return DotExp<TA, TB, ltrans, rtrans,
47 template<
typename TA,
typename TB,
bool ltrans,
bool rtrans>
51 return DotExp<TA, TB, ltrans, rtrans,
56 template<
typename E,
typename DType,
typename R,
int d>
62 template<
typename E,
typename DType,
typename R,
int d>
69 template<
typename OP,
typename TA,
int ta>
73 return MakeExp<OP>(lhs, rhs);
76 template<
typename OP,
typename TB,
int tb>
80 return MakeExp<OP>(lhs, rhs);
86 F(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
87 return MakeExp<OP>(lhs, rhs);
91 template<
typename TA,
int ta>
95 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
96 return MakeExp<op::plus>(lhs, rhs);
99 template<
typename TA,
int ta>
103 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
104 return MakeExp<op::minus>(lhs, rhs);
107 template<
typename TA,
int ta>
111 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
112 return MakeExp<op::mul>(lhs, rhs);
115 template<
typename TA,
int ta>
119 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
120 return MakeExp<op::div>(lhs, rhs);
124 template<
typename TB,
int tb>
127 operator+(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
129 return MakeExp<op::plus>(lhs, rhs);
132 template<
typename TB,
int tb>
135 operator-(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
137 return MakeExp<op::minus>(lhs, rhs);
140 template<
typename TB,
int tb>
143 operator*(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
145 return MakeExp<op::mul>(lhs, rhs);
148 template<
typename TB,
int tb>
152 return MakeExp<op::div>(lhs, rhs);
156 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
158 operator+(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
159 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
160 return MakeExp<op::plus>(lhs, rhs);
163 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
165 operator-(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
166 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
167 return MakeExp<op::minus>(lhs, rhs);
170 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
172 operator*(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
173 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
174 return MakeExp<op::mul>(lhs, rhs);
177 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
179 operator/(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
180 return MakeExp<op::div>(lhs, rhs);
184 #endif // MSHADOW_EXPR_SCALAR_INL_H_ const int kMapper
expression contains element-wise tensor operations, maps an expression to the same shape ...
Definition: expression.h:51
#define MSHADOW_SCALAR_
Definition: tensor.h:1097
const SrcExp & src_
source operand
Definition: reduceto1d.h:46
binary map expression lhs [op] rhs
Definition: expression.h:335
DType scale_
source operand, scale of the
Definition: reduceto1d.h:48
const TB & rhs_
right operand
Definition: expression.h:230
Reduction to a 1-dimensional tensor. Input: Tensor<Device,k> with shape ishape; output: Tensor<Device,1> with shape[0] = ishape[dimkeep].
Definition: reduceto1d.h:42
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:72
DType scale_
scale over result
Definition: expression.h:232
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:225
scalar expression
Definition: expression.h:96
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const TA & lhs_
left operand
Definition: expression.h:228
overloaded + operator between half_t and bf16_t
Definition: base.h:327
DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > operator*(const DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > &lhs, MSHADOW_SCALAR_ rhs)
dot operator def
Definition: expr_scalar-inl.h:41