25 #ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ 26 #define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ 27 #include "../extension.h" 38 template<
typename SrcExp,
typename DType,
int srcdim>
40 public MakeTensorExp<UnpackPatchToColXExp<SrcExp, DType, srcdim>,
68 : img_(img), psize_y_(psize_y), psize_x_(psize_x),
69 pstride_y_(pstride_y), pstride_x_(pstride_x),
70 pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
72 CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y)
73 <<
"UnpackPatchToCol:image shape smaller than patch size";
74 this->i_channel_ = imshape[srcdim - 3];
75 this->i_height_ = imshape[srcdim - 2];
76 this->i_width_ = imshape[srcdim - 1];
79 const index_t o_height = (i_height_ -
80 (pdilate_y * (psize_y - 1) + 1)) / pstride_y + 1;
81 const index_t o_width = (i_width_ -
82 (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
83 this->
shape_[1] = o_height * o_width * num;
107 template<
typename SrcExp,
typename DType,
int etype>
112 ::Error_Expression_Does_Not_Meet_Dimension_Req();
114 (img.
self(), psize_y, psize_x, pstride, pstride, pdilate, pdilate);
120 template<
typename SrcExp,
typename DType,
int etype>
126 ::Error_Expression_Does_Not_Meet_Dimension_Req();
133 template<
typename SrcExp,
typename DType,
int srcdim>
150 const index_t jdivw = j / o_width_;
152 const index_t n = jdivw / o_height_;
169 #endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ const SubType & self(void) const
Definition: expression.h:82
Definition: expr_engine-inl.h:58
used to help static type checking
Definition: expr_engine-inl.h:330
UnpackPatchToColXExp(const SrcExp &img, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: unpack_patch2col.h:61
index_t psize_x_
patch width
Definition: unpack_patch2col.h:47
unpack local (overlap) patches of image to column of mat, can be used to implement convolution...
Definition: unpack_patch2col.h:39
index_t pstride_y_
patch stride
Definition: unpack_patch2col.h:49
UnpackPatchToColXExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > unpack_patch2col(const Exp< SrcExp, DType, etype > &img, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
unpack local (overlap) patches of image to column of mat, can be used to implement convolution after ...
Definition: unpack_patch2col.h:109
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:230
int32_t index_t
type used for indexing
Definition: base.h:343
index_t i_height_
height of img
Definition: unpack_patch2col.h:57
index_t i_width_
width of img
Definition: unpack_patch2col.h:59
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: unpack_patch2col.h:144
index_t pdilate_x_
Definition: unpack_patch2col.h:53
const SrcExp & img_
source operand
Definition: unpack_patch2col.h:43
index_t i_channel_
number of input channel
Definition: unpack_patch2col.h:55
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:157
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
Plan(const UnpackPatchToColXExp< SrcExp, DType, srcdim > &e)
Definition: unpack_patch2col.h:136
index_t pdilate_y_
patch dilation
Definition: unpack_patch2col.h:52
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
overloaded + operator between half_t and bf16_t
Definition: base.h:334
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
index_t pstride_x_
Definition: unpack_patch2col.h:50
index_t psize_y_
patch height
Definition: unpack_patch2col.h:45