mxnet
expr_scalar-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
29 // A macro guard is harmful here because this file must be included multiple times; it is present only to pass cpplint
30 #ifndef MSHADOW_EXPR_SCALAR_INL_H_
31 #define MSHADOW_EXPR_SCALAR_INL_H_
32 // undef the guard so it can be included multiple times
33 #undef MSHADOW_EXPR_SCALAR_INL_H_
34 
35 namespace mshadow {
36 namespace expr {
37 // DotExp
39 template<typename TA, typename TB, bool ltrans, bool rtrans>
40 inline DotExp<TA, TB, ltrans, rtrans, MSHADOW_SCALAR_>
42  MSHADOW_SCALAR_ rhs) {
43  return DotExp<TA, TB, ltrans, rtrans,
44  MSHADOW_SCALAR_>(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs);
45 }
47 template<typename TA, typename TB, bool ltrans, bool rtrans>
51  return DotExp<TA, TB, ltrans, rtrans,
52  MSHADOW_SCALAR_>(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs);
53 }
54 
56 template<typename E, typename DType, typename R, int d>
59  return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
60 }
62 template<typename E, typename DType, typename R, int d>
65  return ReduceTo1DExp<E, DType, R, d>(e.src_, e.scale_ * scale);
66 }
67 
69 template<typename OP, typename TA, int ta>
73  return MakeExp<OP>(lhs, rhs);
74 }
76 template<typename OP, typename TB, int tb>
80  return MakeExp<OP>(lhs, rhs);
81 }
83 template<typename OP>
84 inline BinaryMapExp<OP, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
86 F(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
87  return MakeExp<OP>(lhs, rhs);
88 }
89 // constant operators: expression [op] scalar
91 template<typename TA, int ta>
94 operator+(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
95  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
96  return MakeExp<op::plus>(lhs, rhs);
97 }
99 template<typename TA, int ta>
102 operator-(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
103  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
104  return MakeExp<op::minus>(lhs, rhs);
105 }
107 template<typename TA, int ta>
110 operator*(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
111  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
112  return MakeExp<op::mul>(lhs, rhs);
113 }
115 template<typename TA, int ta>
118 operator/(const Exp<TA, MSHADOW_SCALAR_, ta> &lhs,
119  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
120  return MakeExp<op::div>(lhs, rhs);
121 }
122 // constant operators 2: scalar [op] expression
124 template<typename TB, int tb>
127 operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
128  const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
129  return MakeExp<op::plus>(lhs, rhs);
130 }
132 template<typename TB, int tb>
135 operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
136  const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
137  return MakeExp<op::minus>(lhs, rhs);
138 }
140 template<typename TB, int tb>
143 operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
144  const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
145  return MakeExp<op::mul>(lhs, rhs);
146 }
148 template<typename TB, int tb>
151 operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const Exp<TB, MSHADOW_SCALAR_, tb> &rhs) {
152  return MakeExp<op::div>(lhs, rhs);
153 }
154 // constant operators 3: scalar [op] scalar
156 inline BinaryMapExp<op::plus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
158 operator+(const ScalarExp<MSHADOW_SCALAR_> &lhs,
159  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
160  return MakeExp<op::plus>(lhs, rhs);
161 }
163 inline BinaryMapExp<op::minus, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
165 operator-(const ScalarExp<MSHADOW_SCALAR_> &lhs,
166  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
167  return MakeExp<op::minus>(lhs, rhs);
168 }
170 inline BinaryMapExp<op::mul, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
172 operator*(const ScalarExp<MSHADOW_SCALAR_> &lhs,
173  const ScalarExp<MSHADOW_SCALAR_> &rhs) {
174  return MakeExp<op::mul>(lhs, rhs);
175 }
177 inline BinaryMapExp<op::div, ScalarExp<MSHADOW_SCALAR_>, ScalarExp<MSHADOW_SCALAR_>,
179 operator/(const ScalarExp<MSHADOW_SCALAR_> &lhs, const ScalarExp<MSHADOW_SCALAR_> &rhs) {
180  return MakeExp<op::div>(lhs, rhs);
181 }
182 } // namespace expr
183 } // namespace mshadow
184 #endif // MSHADOW_EXPR_SCALAR_INL_H_
const int kMapper
expression contains element-wise tensor operations, mapping an expression to one of the same shape ...
Definition: expression.h:51
#define MSHADOW_SCALAR_
Definition: tensor.h:1097
const SrcExp & src_
source operand
Definition: reduceto1d.h:46
binary map expression lhs [op] rhs
Definition: expression.h:335
DType scale_
scale factor of the reduction expression (multiplying the expression by a scalar multiplies this value)
Definition: reduceto1d.h:48
const TB & rhs_
right operand
Definition: expression.h:230
reduction to a 1-dimensional tensor. Input: Tensor<Device, k> with shape ishape; output: Tensor<Device, 1> with shape[0] = ishape[dimkeep].
Definition: reduceto1d.h:42
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload between an expression and a constant (scalar) operand
Definition: expr_scalar-inl.h:72
DType scale_
scalar multiplier applied to the result
Definition: expression.h:232
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:225
scalar expression
Definition: expression.h:96
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const TA & lhs_
left operand
Definition: expression.h:228
overloaded + operator between half_t and bf16_t
Definition: base.h:327
DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > operator*(const DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > &lhs, MSHADOW_SCALAR_ rhs)
scales a dot expression by a scalar (the scalar is folded into the expression's scale_)
Definition: expr_scalar-inl.h:41