29 #ifndef MSHADOW_EXPR_SCALAR_INL_H_ 30 #define MSHADOW_EXPR_SCALAR_INL_H_ 32 #undef MSHADOW_EXPR_SCALAR_INL_H_ 38 template<
typename TA,
typename TB,
bool ltrans,
bool rtrans>
39 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
42 return DotExp<TA, TB, ltrans, rtrans,
46 template<
typename TA,
typename TB,
bool ltrans,
bool rtrans>
50 return DotExp<TA, TB, ltrans, rtrans,
55 template<
typename E,
typename DType,
typename R,
int d>
61 template<
typename E,
typename DType,
typename R,
int d>
68 template<
typename OP,
typename TA,
int ta>
72 return MakeExp<OP>(lhs, rhs);
75 template<
typename OP,
typename TB,
int tb>
79 return MakeExp<OP>(lhs, rhs);
85 F(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
86 return MakeExp<OP>(lhs, rhs);
90 template<
typename TA,
int ta>
94 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
95 return MakeExp<op::plus>(lhs, rhs);
98 template<
typename TA,
int ta>
102 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
103 return MakeExp<op::minus>(lhs, rhs);
106 template<
typename TA,
int ta>
110 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
111 return MakeExp<op::mul>(lhs, rhs);
114 template<
typename TA,
int ta>
118 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
119 return MakeExp<op::div>(lhs, rhs);
123 template<
typename TB,
int tb>
126 operator+(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
128 return MakeExp<op::plus>(lhs, rhs);
131 template<
typename TB,
int tb>
134 operator-(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
136 return MakeExp<op::minus>(lhs, rhs);
139 template<
typename TB,
int tb>
142 operator*(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
144 return MakeExp<op::mul>(lhs, rhs);
147 template<
typename TB,
int tb>
151 return MakeExp<op::div>(lhs, rhs);
155 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
157 operator+(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
158 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
159 return MakeExp<op::plus>(lhs, rhs);
162 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
164 operator-(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
165 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
166 return MakeExp<op::minus>(lhs, rhs);
169 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
171 operator*(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
172 const ScalarExp<MSHADOW_SCALAR_> &rhs) {
173 return MakeExp<op::mul>(lhs, rhs);
176 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
178 operator/(
const ScalarExp<MSHADOW_SCALAR_> &lhs,
const ScalarExp<MSHADOW_SCALAR_> &rhs) {
179 return MakeExp<op::div>(lhs, rhs);
183 #endif // MSHADOW_EXPR_SCALAR_INL_H_ const int kMapper
expression contains element-wise tensor operations, mapping an expression to the same shape ...
Definition: expression.h:50
#define MSHADOW_SCALAR_
Definition: tensor.h:1096
const SrcExp & src_
source operand
Definition: reduceto1d.h:45
binary map expression lhs [op] rhs
Definition: expression.h:334
DType scale_
source operand, scale of the
Definition: reduceto1d.h:47
const TB & rhs_
right operand
Definition: expression.h:229
reduction to a 1-dimensional tensor; input: Tensor<Device,k> with shape ishape; output: Tensor<Device,1> with shape[0] = ishape[dimkeep]
Definition: reduceto1d.h:41
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for a constant (scalar) right-hand operand
Definition: expr_scalar-inl.h:71
DType scale_
scale over result
Definition: expression.h:231
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:224
scalar expression
Definition: expression.h:95
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
const TA & lhs_
left operand
Definition: expression.h:227
overloaded + operator between half_t and bf16_t
Definition: base.h:334
DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > operator*(const DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > &lhs, MSHADOW_SCALAR_ rhs)
definition of operator* between a dot expression and a scalar
Definition: expr_scalar-inl.h:40