26 #ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_ 27 #define MSHADOW_EXTENSION_PACK_COL2PATCH_H_ 29 #include "../extension.h" 40 template<
typename SrcExp,
typename DType,
int dstdim>
42 public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>,
43 SrcExp, dstdim, DType> {
61 :src_(src), psize_y_(psize_y), psize_x_(psize_x),
62 pstride_y_(pstride_y), pstride_x_(pstride_x),
63 pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
65 const index_t o_height = (imshape[dstdim - 2] -
66 (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1;
67 const index_t o_width = (imshape[dstdim - 1] -
68 (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
70 CHECK_EQ(sshape[1], o_height * o_width * imshape.
ProdShape(0, dstdim - 3))
71 <<
"PackColToPatchExp: src.size(1) mismatch";
72 CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3])
73 <<
"PackColToPatchExp: src.size(0) mismatch";
89 template<
typename SrcExp,
typename DType,
int dstdim,
int etype>
95 ::Error_Expression_Does_Not_Meet_Dimension_Req();
96 CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
97 <<
"PackColToPatch:image shape smaller than patch size";
99 psize_y, psize_x, pstride, pstride,
105 template<
typename SrcExp,
typename DType,
int dstdim,
int etype>
112 ::Error_Expression_Does_Not_Meet_Dimension_Req();
113 CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
114 <<
"PackColToPatch:image shape smaller than patch size";
116 psize_y, psize_x, pstride_y, pstride_x,
117 pdilate_y, pdilate_x);
123 template<
typename SrcExp,
typename DType,
int dstdim>
130 i_height_(e.
shape_[dstdim - 2]),
139 const index_t y = i % i_height_;
140 const index_t idivh = i / i_height_;
141 const index_t c = idivh % i_channel_;
142 const index_t n = idivh / i_channel_;
154 DType res =
static_cast<DType
>(0);
159 (n * o_height_ + py) * o_width_ + px);
169 const index_t i_height_, o_height_, o_width_;
173 #endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_ PackColToPatchXExp< SrcExp, DType, dstdim > pack_col2patch(const expr::Exp< SrcExp, DType, etype > &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
reverse operation of pack_col2patch, can be used to implement deconvolution
Definition: pack_col2patch.h:91
index_t pdilate_x_
Definition: pack_col2patch.h:55
Definition: expr_engine-inl.h:59
used to help static type check
Definition: expr_engine-inl.h:331
shape of a tensor
Definition: tensor.h:54
const SrcExp & src_
source operand
Definition: pack_col2patch.h:45
Definition: optional.h:241
index_t pstride_x_
Definition: pack_col2patch.h:52
index_t psize_y_
patch height
Definition: pack_col2patch.h:47
reverse operation of UnpackPatchToCol, used to backprop gradient back this is a version supporting mu...
Definition: pack_col2patch.h:41
index_t pstride_y_
patch stride
Definition: pack_col2patch.h:51
index_t psize_x_
patch width
Definition: pack_col2patch.h:49
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:223
int32_t index_t
type that will be used for index
Definition: base.h:336
PackColToPatchXExp(const SrcExp &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: pack_col2patch.h:57
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:158
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: pack_col2patch.h:137
index_t pdilate_y_
patch dilate
Definition: pack_col2patch.h:54
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:44
overloaded + operator between half_t and bf16_t
Definition: base.h:327
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
Plan(const PackColToPatchXExp< SrcExp, DType, dstdim > &e)
Definition: pack_col2patch.h:126