25 #ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_ 26 #define MSHADOW_EXTENSION_PACK_COL2PATCH_H_ 28 #include "../extension.h" 39 template<
typename SrcExp,
typename DType,
int dstdim>
41 public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>,
42 SrcExp, dstdim, DType> {
60 :src_(src), psize_y_(psize_y), psize_x_(psize_x),
61 pstride_y_(pstride_y), pstride_x_(pstride_x),
62 pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
64 const index_t o_height = (imshape[dstdim - 2] -
65 (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1;
66 const index_t o_width = (imshape[dstdim - 1] -
67 (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
69 CHECK_EQ(sshape[1], o_height * o_width * imshape.
ProdShape(0, dstdim - 3))
70 <<
"PackColToPatchExp: src.size(1) mismatch";
71 CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3])
72 <<
"PackColToPatchExp: src.size(0) mismatch";
88 template<
typename SrcExp,
typename DType,
int dstdim,
int etype>
94 ::Error_Expression_Does_Not_Meet_Dimension_Req();
95 CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
96 <<
"PackColToPatch:image shape smaller than patch size";
98 psize_y, psize_x, pstride, pstride,
104 template<
typename SrcExp,
typename DType,
int dstdim,
int etype>
111 ::Error_Expression_Does_Not_Meet_Dimension_Req();
112 CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
113 <<
"PackColToPatch:image shape smaller than patch size";
115 psize_y, psize_x, pstride_y, pstride_x,
116 pdilate_y, pdilate_x);
122 template<
typename SrcExp,
typename DType,
int dstdim>
129 i_height_(e.
shape_[dstdim - 2]),
138 const index_t y = i % i_height_;
139 const index_t idivh = i / i_height_;
140 const index_t c = idivh % i_channel_;
141 const index_t n = idivh / i_channel_;
153 DType res =
static_cast<DType
>(0);
158 (n * o_height_ + py) * o_width_ + px);
168 const index_t i_height_, o_height_, o_width_;
172 #endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_ PackColToPatchXExp< SrcExp, DType, dstdim > pack_col2patch(const expr::Exp< SrcExp, DType, etype > &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
reverse operation of unpack_patch2col (UnpackPatchToCol); can be used to implement deconvolution
Definition: pack_col2patch.h:90
index_t pdilate_x_
Definition: pack_col2patch.h:54
const SubType & self(void) const
Definition: expression.h:82
Definition: expr_engine-inl.h:58
used to help static type check
Definition: expr_engine-inl.h:330
shape of a tensor
Definition: tensor.h:53
const SrcExp & src_
source operand
Definition: pack_col2patch.h:44
Definition: optional.h:251
index_t pstride_x_
Definition: pack_col2patch.h:51
index_t psize_y_
patch height
Definition: pack_col2patch.h:46
reverse operation of UnpackPatchToCol, used to backprop the gradient; this version supports multiple images (batched input).
Definition: pack_col2patch.h:40
index_t pstride_y_
patch stride
Definition: pack_col2patch.h:50
index_t psize_x_
patch width
Definition: pack_col2patch.h:48
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:230
int32_t index_t
type that will be used for index
Definition: base.h:343
PackColToPatchXExp(const SrcExp &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: pack_col2patch.h:56
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:157
index_t pdilate_y_
patch dilation
Definition: pack_col2patch.h:53
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: pack_col2patch.h:136
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
overloaded + operator between half_t and bf16_t
Definition: base.h:334
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
Plan(const PackColToPatchXExp< SrcExp, DType, dstdim > &e)
Definition: pack_col2patch.h:125