25 #ifndef MSHADOW_EXTENSION_COMPLEX_H_ 26 #define MSHADOW_EXTENSION_COMPLEX_H_ 28 #include "../extension.h" 37 template<
typename DType>
39 DType b_real, DType b_imag) {
40 return a_real * b_real - a_imag * b_imag;
42 template<
typename DType>
44 DType b_real, DType b_imag) {
45 return a_real * b_imag + b_real * a_imag;
51 template<
typename DType>
53 DType b_real, DType b_imag) {
54 return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag);
56 template<
typename DType>
58 DType b_real, DType b_imag) {
59 return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag);
64 template<
typename TA,
typename DType>
67 return src_.
Eval(real_i, real_j);
69 template<
typename TA,
typename DType>
72 return -src_.
Eval(imag_i, imag_j);
77 template<
typename TA,
typename DType>
80 return src_.
Eval(imag_i, imag_j);
82 template<
typename TA,
typename DType>
85 return src_.
Eval(real_i, real_j);
91 template<
typename TA,
typename DType>
94 return src_.
Eval(real_i, real_j);
96 template<
typename TA,
typename DType>
105 template<
typename TA,
typename DType>
108 DType real_val = src_.
Eval(real_i, real_j);
114 template<
typename TA,
typename DType>
117 DType real_val = src_.
Eval(real_i, real_j);
118 DType image_val = src_.
Eval(imag_i, imag_j);
119 return real_val * real_val + image_val * image_val;
124 template<
typename TA,
typename DType>
127 DType real_val = src_.
Eval(real_i, real_j);
128 DType image_val = src_.
Eval(imag_i, imag_j);
129 return real_val + image_val;
147 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
156 :lhs_(lhs), rhs_(rhs) {}
167 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
178 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int ta,
int tb>
190 template<
int calctype,
typename OP,
typename SrcExp,
typename DType,
int e1>
199 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
203 return ComplexF<op::complex::kBinaryCC, op::complex::mul>(lhs, rhs);
209 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
213 return ComplexF<op::complex::kBinaryCR, op::complex::mul>(lhs, rhs);
219 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
223 return ComplexF<op::complex::kBinaryRC, op::complex::mul>(lhs, rhs);
229 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
233 return ComplexF<op::complex::kBinaryCC, op::complex::div>(lhs, rhs);
239 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
243 return ComplexF<op::complex::kBinaryCR, op::complex::div>(lhs, rhs);
249 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
253 return ComplexF<op::complex::kBinaryRC, op::complex::div>(lhs, rhs);
261 template<
typename SrcExp,
typename DType,
int e1>
265 return ComplexF<op::complex::kUnitaryC2C, op::complex::conjugate>(src);
273 template<
typename SrcExp,
typename DType,
int e1>
277 return ComplexF<op::complex::kUnitaryC2C, op::complex::exchange>(src);
285 template<
typename SrcExp,
typename DType,
int e1>
289 return ComplexF<op::complex::kUnitaryR2C, op::complex::pad_imag>(src);
297 template<
typename SrcExp,
typename DType,
int e1>
301 return ComplexF<op::complex::kUnitaryC2R, op::complex::toreal>(src);
309 template<
typename SrcExp,
typename DType,
int e1>
313 return ComplexF<op::complex::kUnitaryC2R, op::complex::abs_square>(src);
316 template<
typename SrcExp,
typename DType,
int e1>
320 return ComplexF<op::complex::kUnitaryC2R, op::complex::sum_real_imag>(src);
323 template<
int dim,
int calctype,
typename OP,
typename TA,
typename TB,
324 typename DType,
int etype>
330 if (shape1[0] == 0)
return shape2;
331 if (shape2[0] == 0)
return shape1;
332 if (calctype == op::complex::kBinaryCC) {
333 CHECK_EQ(shape1, shape2) <<
"ComplexBinaryMapExp (CC): Shapes of operands are not the same.";
334 CHECK_EQ(shape1[dim - 1] % 2, 0) <<
335 "ComplexBinaryMapExp (CC): Shape of the last dimension is not even. " 336 "We must have real part + imaginary part.";
338 }
else if (calctype == op::complex::kBinaryCR) {
339 for (
int i = 0; i < dim - 1; ++i) {
341 "ComplexBinaryMapExp (CR): Shapes of operands are not the same.";
343 CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) <<
344 "ComplexBinaryMapExp (CR): Shapes of operands do not match.";
346 }
else if (calctype == op::complex::kBinaryRC) {
347 for (
int i = 0; i < dim - 1; ++i) {
349 "ComplexBinaryMapExp (RC): Shapes of operands are not the same.";
351 CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) <<
352 "ComplexBinaryMapExp (RC): Shapes of operands do not match.";
355 LOG(FATAL) <<
"ComplexBinaryMapExp: Unexpected Calculation Type!";
361 template<
int dim,
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
365 CHECK_EQ(s[dim - 1] % 2, 0) <<
"ComplexUnitaryExp: Shape of the last dimension is not even. " 366 "We must have real + imaginary.";
367 if (calctype == op::complex::kUnitaryC2C) {
369 }
else if (calctype == op::complex::kUnitaryC2R) {
373 }
else if (calctype == op::complex::kUnitaryR2C) {
378 LOG(FATAL) <<
"ComplexUnitaryExp: Unexpected Calculation Type!";
387 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
391 : lhs_(lhs), rhs_(rhs) {}
395 return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
396 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
398 return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
399 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
409 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
413 : lhs_(lhs), rhs_(rhs) {}
417 return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
418 rhs_.Eval(y, base_x / 2),
static_cast<DType
>(0));
420 return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1),
421 rhs_.Eval(y, base_x / 2),
static_cast<DType
>(0));
432 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
436 : lhs_(lhs), rhs_(rhs) {}
440 return OP::RealMap(lhs_.Eval(y, base_x / 2),
static_cast<DType
>(0),
441 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
443 return OP::ImagMap(lhs_.Eval(y, base_x / 2),
static_cast<DType
>(0),
444 rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1));
455 template<
typename OP,
typename TA,
int etype,
typename DType>
462 return OP::RealMap(src_, y, base_x, y, base_x + 1);
464 return OP::ImagMap(src_, y, base_x, y, base_x + 1);
473 template<
typename OP,
typename TA,
int etype,
typename DType>
483 return OP::RealMap(src_, y, real_x);
485 return OP::ImagMap(src_, y, real_x);
494 template<
typename OP,
typename TA,
int etype,
typename DType>
499 return OP::RealMap(src_, y, x * 2, y, x * 2 + 1);
508 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
511 return Plan<ComplexBinaryMapExp<calctype, OP, TA, TB, DType, etype>,
515 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
518 return Plan<ComplexUnitaryExp<calctype, OP, TA, DType, etype>,
524 template<
int calctype,
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
528 static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \
531 ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
535 template<
int calctype,
typename OP,
typename TA,
typename DType,
int etype>
543 #endif // MSHADOW_EXTENSION_COMPLEX_H_ ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_rc Complex multiplication of a real tensor A and a complex tensor B
Definition: complex.h:222
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to the result using the defined operation
Definition: complex.h:52
Plan< ComplexUnitaryExp< calctype, OP, TA, DType, etype >, DType > MakePlan(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &e)
Definition: complex.h:517
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:83
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:70
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:437
const int kMapper
expression contains element-wise tensor operations; maps an expression to the same shape ...
Definition: expression.h:50
ComplexUnitaryExp(const TA &src)
constructor
Definition: complex.h:173
const SubType & self(void) const
Definition: expression.h:82
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::abs_square, SrcExp, DType,(e1|type::kMapper)> complex_abs_square(const Exp< SrcExp, DType, e1 > &src)
complex_abs_square Calculate the square of the modulus of A, where A is a complex tensor ...
Definition: complex.h:312
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:125
UnitaryCalculationType
Definition: complex.h:34
const TB & rhs_
right operand
Definition: complex.h:153
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:392
ComplexUnitaryExp< calctype, OP, SrcExp, DType,(e1|type::kMapper)> ComplexF(const Exp< SrcExp, DType, e1 > &src)
ComplexF Build a unary complex expression that applies OP to src with calculation type calctype
Definition: complex.h:192
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:106
ComplexUnitaryExp< op::complex::kUnitaryR2C, op::complex::pad_imag, SrcExp, DType,(e1|type::kMapper)> complex_pad_imag(const Exp< SrcExp, DType, e1 > &src)
complex_pad_imag Transform a real matrix into a complex matrix
Definition: complex.h:288
static MSHADOW_XINLINE DType ImagMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:97
static Shape< dim > Check(const ComplexBinaryMapExp< calctype, OP, TA, TB, DType, etype > &t)
Definition: complex.h:327
BinaryCalculationType
Definition: complex.h:33
Definition: complex.h:123
const TA & lhs_
left operand
Definition: complex.h:151
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cc Complex multiplication of two complex tensors, A * B
Definition: complex.h:202
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::conjugate, SrcExp, DType,(e1|type::kMapper)> conj(const Exp< SrcExp, DType, e1 > &src)
conj Negate the imaginary part of A, where A is a complex tensor
Definition: complex.h:264
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
evaluate the expression at index [y][x] to be implemented by SubType, for RValue, the return type wil...
#define MSHADOW_XINLINE
Definition: base.h:230
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:262
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:459
const TA & src_
source expression
Definition: complex.h:171
int32_t index_t
type that will be used for index
Definition: base.h:343
static Shape< dim > Check(const ComplexUnitaryExp< calctype, OP, TA, DType, etype > &t)
Definition: complex.h:363
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j)
Definition: complex.h:92
ComplexBinaryMapExp< op::complex::kBinaryCC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cc Complex division of two complex tensors, A / B
Definition: complex.h:232
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:477
static MSHADOW_XINLINE DType RealMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
map a_real, a_imag, b_real, b_imag to the result using the defined operation
Definition: complex.h:38
Plan(const Plan< TA, DType > &src)
Definition: complex.h:497
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:435
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:75
runtime shape checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:364
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:65
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:115
ComplexUnitaryExp< op::complex::kUnitaryC2C, op::complex::exchange, SrcExp, DType,(e1|type::kMapper)> complex_exchange(const Exp< SrcExp, DType, e1 > &src)
complex_exchange Exchange the real and imaginary parts of A, where A is a complex tensor ...
Definition: complex.h:276
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:57
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:390
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::mul, TA, TB, DType,(ta|tb|type::kMapper)> complex_mul_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_mul_cr Complex multiplication of a complex tensor A and a real tensor B
Definition: complex.h:212
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: complex.h:412
Definition: complex.h:104
ComplexBinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: complex.h:155
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::sum_real_imag, SrcExp, DType,(e1|type::kMapper)> complex_sum_real_imag(const Exp< SrcExp, DType, e1 > &src)
Definition: complex.h:319
static MSHADOW_XINLINE DType ImagMap(DType a_real, DType a_imag, DType b_real, DType b_imag)
Definition: complex.h:43
ComplexUnitaryExp< op::complex::kUnitaryC2R, op::complex::toreal, SrcExp, DType,(e1|type::kMapper)> complex_toreal(const Exp< SrcExp, DType, e1 > &src)
complex_toreal Convert a complex matrix to a real matrix, keeping only the real part
Definition: complex.h:300
overloaded + operator between half_t and bf16_t
Definition: base.h:334
Plan(const Plan< TA, DType > &src)
Definition: complex.h:476
Plan(const Plan< TA, DType > &src)
Definition: complex.h:458
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:414
static MSHADOW_XINLINE DType RealMap(const expr::Plan< TA, DType > &src_, index_t real_i, index_t real_j, index_t imag_i, index_t imag_j)
Definition: complex.h:78
ComplexBinaryMapExp< op::complex::kBinaryRC, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_rc(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_rc Complex division of a real tensor A by a complex tensor B
Definition: complex.h:252
Definition: complex.h:113
unary map expression OP(src); calctype selects complex-to-complex, complex-to-real, or real-to-complex
Definition: complex.h:168
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: complex.h:498
ComplexBinaryMapExp< op::complex::kBinaryCR, op::complex::div, TA, TB, DType,(ta|tb|type::kMapper)> complex_div_cr(const Exp< TA, DType, ta > &lhs, const Exp< TB, DType, tb > &rhs)
complex_div_cr Complex division of a complex tensor A by a real tensor B
Definition: complex.h:242
binary map expression lhs [op] rhs, where lhs and/or rhs is a complex tensor (as selected by calctype)
Definition: complex.h:148