mxnet
transpose.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_EXTENSION_TRANSPOSE_H_
27 #define MSHADOW_EXTENSION_TRANSPOSE_H_
28 #include <algorithm>
29 #include "../extension.h"
30 namespace mshadow {
31 namespace expr {
43 template<typename SrcExp, typename DType, int dimsrc>
45  public MakeTensorExp<TransposeExExp<SrcExp, DType, dimsrc>,
46  SrcExp, dimsrc, DType> {
48  const SrcExp &src_;
50  Shape<dimsrc> dst_in_src_stride_; // Holds the corresponding stride of the dst axes in src
53  explicit TransposeExExp(const SrcExp &src, Shape<dimsrc> axes) : src_(src), axes_(axes) {
55  src_stride_ = src_shape[dimsrc - 1];
56  Shape<dimsrc> src_stride;
57  src_stride[dimsrc-1] = 1;
58  for (int i = dimsrc-2; i >= 0; --i) src_stride[i] = src_shape[i+1]*src_stride[i+1];
59  for (int i = 0; i < dimsrc; ++i) {
60  dst_in_src_stride_[i] = src_stride[axes[i]];
61  this->shape_[i] = src_shape[axes[i]];
62  }
63  }
64 };
75 template<typename SrcExp, typename DType, int etype>
79 }
80 
81 template<typename SrcExp, typename DType, int dimsrc>
82 struct Plan<TransposeExExp<SrcExp, DType, dimsrc>, DType> {
83  public:
85  : src_(MakePlan(e.src_)),
88  dst_shape_(e.shape_) {}
89  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
90  index_t idx = j * dst_in_src_stride_[dimsrc - 1];
91  #pragma unroll
92  for (int k = dimsrc-2; k >= 0; --k) {
93  idx += (i % dst_shape_[k]) * dst_in_src_stride_[k];
94  i /= dst_shape_[k];
95  }
96  return src_.Eval(idx/src_stride_, idx%src_stride_);
97  }
98 
99  private:
101  const index_t src_stride_;
102  const Shape<dimsrc> dst_in_src_stride_, dst_shape_;
103 };
104 
115 template<typename SrcExp, typename DType, int dimsrc, int etype>
117  public Exp<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType, etype> {
119  const SrcExp &src_indices_; // Expression of the source indices
120  Shape<dimsrc> src_shape_; // Shape of the source tensor whose contiguous indices are transformed
121  const Shape<dimsrc> axes_; // The transpose axes
122  Shape<dimsrc> src_in_dst_stride_; // For each source axis k, the stride of that axis in the transposed (dst) tensor
124  explicit TransposeIndicesExp(const SrcExp &src_indices,
125  Shape<dimsrc> src_shape,
126  Shape<dimsrc> axes) : src_indices_(src_indices),
127  src_shape_(src_shape), axes_(axes) {
128  Shape<dimsrc> dst_shape_;
129  Shape<dimsrc> dst_stride_;
130  bool axes_checking_flag[dimsrc] = { 0 };
131  for (int i = 0; i < dimsrc; ++i) {
132  CHECK_LT(static_cast<int>(axes[i]), dimsrc)
133  << "Invalid axes input! All elements of axes must be between 0 and " << dimsrc
134  << ", find axes=" << axes;
135  dst_shape_[i] = src_shape[axes[i]];
136  axes_checking_flag[axes[i]] = true;
137  }
138  // check if the input axes is valid
139  for (int i = 0; i < dimsrc; ++i) {
140  CHECK_EQ(axes_checking_flag[i], true)
141  << "Invalid axes input! All elements of axes must be between 0 and " << dimsrc
142  << ", find axes=" << axes;
143  }
144  dst_stride_[dimsrc - 1] = 1;
145  for (int i = dimsrc - 2; i >= 0; --i) dst_stride_[i] = dst_shape_[i+1] * dst_stride_[i+1];
146  for (int i = 0; i < dimsrc; ++i) {
147  src_in_dst_stride_[axes[i]] = dst_stride_[i];
148  }
149  }
150 };
151 
162 template<typename SrcExp, typename DType, int dimsrc, int etype>
165  Shape<dimsrc> src_shape,
166  Shape<dimsrc> axes) {
167  return TransposeIndicesExp<SrcExp, DType, dimsrc, etype>(src_indices.self(), src_shape, axes);
168 }
169 
170 template<typename SrcExp, typename DType, int dimsrc, int etype>
171 struct Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType> {
172  public:
174  : src_indices_(MakePlan(e.src_indices_)),
175  src_in_dst_stride_(e.src_in_dst_stride_),
176  src_shape_(e.src_shape_) {}
177  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
178  index_t src_idx = static_cast<index_t>(src_indices_.Eval(i, j));
179  index_t dst_idx = 0;
180  #pragma unroll
181  for (int k = dimsrc - 1; k >= 0; --k) {
182  dst_idx += (src_idx % src_shape_[k]) * src_in_dst_stride_[k];
183  src_idx /= src_shape_[k];
184  }
185  return static_cast<DType>(dst_idx);
186  }
187 
188  private:
189  Plan<SrcExp, DType> src_indices_;
190  const Shape<dimsrc> src_in_dst_stride_, src_shape_;
191 };
192 
193 //----------------------
194 // Execution plan
195 //----------------------
197 template<typename SrcExp, typename DType, int dimsrc, int etype>
200  return Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType>(e);
201 }
202 
203 template<int dim, typename SrcExp, typename DType, int dimsrc, int etype>
204 struct ShapeCheck<dim, TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > {
205  inline static Shape<dim>
208  return s;
209  }
210 };
211 
212 template<typename SrcExp, typename DType, int dimsrc, int etype>
213 struct ExpInfo<TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > {
214  static const int kDim = ExpInfo<SrcExp>::kDim;
215  static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
216 };
217 } // namespace expr
218 } // namespace mshadow
219 #endif // MSHADOW_EXTENSION_TRANSPOSE_H_
TransposeExExp(const SrcExp &src, Shape< dimsrc > axes)
constructor
Definition: transpose.h:53
Definition: expr_engine-inl.h:59
static Shape< dim > Check(const TransposeIndicesExp< SrcExp, DType, dimsrc, etype > &t)
Definition: transpose.h:206
Plan(const TransposeExExp< SrcExp, DType, dimsrc > &e)
Definition: transpose.h:84
const Shape< dimsrc > axes_
Definition: transpose.h:49
transpose axes of a tensor. input: Tensor<Device,dim>: ishape; output: Tensor<Device,dim>: oshape, where oshape[i] = ishape[axes[i]]
Definition: transpose.h:44
index_t src_stride_
Definition: transpose.h:51
Shape< dimsrc > dst_in_src_stride_
Definition: transpose.h:50
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:223
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
const SrcExp & src_
source expression
Definition: transpose.h:48
int32_t index_t
type that will be used for index
Definition: base.h:336
Shape< dimsrc > src_in_dst_stride_
Definition: transpose.h:122
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: transpose.h:89
runtime shape checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:365
Plan(const TransposeIndicesExp< SrcExp, DType, dimsrc, etype > &e)
Definition: transpose.h:173
TransposeIndicesExp< SrcExp, DType, dimsrc, etype > transpose_indices(const Exp< SrcExp, DType, etype > &src_indices, Shape< dimsrc > src_shape, Shape< dimsrc > axes)
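an expression that maps contiguous indices of a source tensor of shape src_shape to the corresponding contiguous indices of the tensor transposed by axes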
Definition: transpose.h:164
transform contiguous indices of the source tensor to indices of the transposed tensor. input: Tensor<Device, k>: ishape; output: Tensor<Device, k>: oshape = ishape
Definition: transpose.h:116
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
TransposeIndicesExp(const SrcExp &src_indices, Shape< dimsrc > src_shape, Shape< dimsrc > axes)
constructor
Definition: transpose.h:124
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:44
const Shape< dimsrc > axes_
Definition: transpose.h:121
overloaded + operator between half_t and bf16_t
Definition: base.h:327
Shape< dimsrc > src_shape_
Definition: transpose.h:120
TransposeExExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > transpose(const Exp< SrcExp, DType, etype > &src, Shape< ExpInfo< SrcExp >::kDim > axes)
an expression that permutes the axes of a tensor according to axes
Definition: transpose.h:77
const SrcExp & src_indices_
source expression
Definition: transpose.h:119
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: transpose.h:177