mxnet
pack_col2patch.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_
8 #define MSHADOW_EXTENSION_PACK_COL2PATCH_H_
9 #include <algorithm>
10 #include "../extension.h"
11 namespace mshadow {
12 namespace expr {
// PackColToPatchXExp: expression node for the reverse of UnpackPatchToCol.
// It sums entries of a column matrix `src` back into an image of shape
// `imshape`. Every image dimension before dstdim - 3 acts as a batch axis
// (see the ProdShape(0, dstdim - 3) factor in the size check below).
// NOTE(review): this listing omits original lines 22, 25, 27-37 and 50
// (the struct head, member declarations and the definition of `sshape`).
21 template<typename SrcExp, typename DType, int dstdim>
23  public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>,
24  SrcExp, dstdim, DType> {
26  const SrcExp &src_;
// Constructor.
//   src        column matrix. size(0) must equal psize_y * psize_x * imshape[dstdim - 3],
//              and size(1) must equal o_height * o_width * imshape.ProdShape(0, dstdim - 3)
//              (enforced by the CHECK_EQs below, assuming `sshape` is src's shape; confirm).
//   imshape    shape of the produced image; it becomes this expression's shape_.
//   psize_*    patch size along height (y, dim dstdim - 2) / width (x, dim dstdim - 1).
//   pstride_*  stride between consecutive patches.
//   pdilate_*  spacing between taps inside a patch.
38  PackColToPatchXExp(const SrcExp &src, Shape<dstdim> imshape,
39  index_t psize_y, index_t psize_x,
40  index_t pstride_y, index_t pstride_x,
41  index_t pdilate_y, index_t pdilate_x)
42  :src_(src), psize_y_(psize_y), psize_x_(psize_x),
43  pstride_y_(pstride_y), pstride_x_(pstride_x),
44  pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
45  this->shape_ = imshape;
// Output-grid size, computed from the dilated patch extent pdilate * (psize - 1) + 1.
46  const index_t o_height = (imshape[dstdim - 2] -
47  (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1;
48  const index_t o_width = (imshape[dstdim - 1] -
49  (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
// NOTE(review): `sshape` is declared on a line not shown in this listing
// (original line 50); presumably it is the shape of `src`.
51  CHECK_EQ(sshape[1], o_height * o_width * imshape.ProdShape(0, dstdim - 3))
52  << "PackColToPatchExp: src.size(1) mismatch";
53  CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3])
54  << "PackColToPatchExp: src.size(0) mismatch";
55  }
56 };
// pack_col2patch: convenience overload that uses one stride and one dilation
// for both spatial axes. It forwards to PackColToPatchXExp.
// Full signature, as given in the generated docs:
//   pack_col2patch(const expr::Exp<SrcExp, DType, etype> &src, Shape<dstdim> imshape,
//                  index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
// NOTE(review): original lines 71-72 and 75 are not shown in this listing.
// Line 76 below is presumably the tail of a static dimension check on `src`.
70 template<typename SrcExp, typename DType, int dstdim, int etype>
73  Shape<dstdim> imshape, index_t psize_y,
74  index_t psize_x, index_t pstride, index_t pdilate) {
76  ::Error_Expression_Does_Not_Meet_Dimension_Req();
// NOTE(review): this check compares the image with the *undilated* patch size.
// When pdilate > 1 a patch spans pdilate * (psize - 1) + 1 pixels. An image
// smaller than that passes here, and the output-size formula in the
// PackColToPatchXExp constructor then divides a negative numerator, which
// yields a wrong output size. Consider checking the dilated extent instead.
77  CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
78  << "PackColToPatch:image shape smaller than patch size";
79  return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape,
80  psize_y, psize_x, pstride, pstride,
81  pdilate, pdilate);
82 }
// pack_col2patch: overload with independent stride and dilation for the
// height (y) and width (x) axes. It forwards to PackColToPatchXExp.
// NOTE(review): original lines 87-88 (return type, name and `src` parameter)
// and 92 are not shown in this listing. Line 93 below is presumably the tail
// of a static dimension check on `src`.
86 template<typename SrcExp, typename DType, int dstdim, int etype>
89  Shape<dstdim> imshape, index_t psize_y,
90  index_t psize_x, index_t pstride_y, index_t pstride_x,
91  index_t pdilate_y, index_t pdilate_x) {
93  ::Error_Expression_Does_Not_Meet_Dimension_Req();
// NOTE(review): as in the other overload, this compares against the
// undilated patch size. With pdilate_* > 1 the constructor's output-size
// formula uses pdilate * (psize - 1) + 1 and can be handed an image smaller
// than that extent. Consider checking the dilated extent here.
94  CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
95  << "PackColToPatch:image shape smaller than patch size";
96  return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape,
97  psize_y, psize_x, pstride_y, pstride_x,
98  pdilate_y, pdilate_x);
99 }
100 
101 //----------------------
102 // Execution plan
103 //----------------------
// Execution plan for PackColToPatchXExp. Eval(i, j) produces one element of
// the output image: row i flattens (n, channel c, y) and column j is x.
// NOTE(review): original lines 107-109 (start of the constructor and its
// initializer list) are not shown in this listing.
104 template<typename SrcExp, typename DType, int dstdim>
105 struct Plan<PackColToPatchXExp<SrcExp, DType, dstdim>, DType> {
106  public:
// NOTE(review): o_height_ / o_width_ below read the *members* psize_*_,
// pstride_*_ and pdilate_*_ rather than e's fields. That is only safe if those
// members are declared before o_height_ / o_width_, because members are
// initialized in declaration order. psize_*_ and pstride_*_ are declared
// earlier (line 148 vs 150). The pdilate_*_ declarations are not visible
// here; confirm.
110  i_channel_(e.shape_[dstdim - 3]), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_),
111  i_height_(e.shape_[dstdim - 2]),
112  o_height_((e.shape_[dstdim - 2] - (pdilate_y_ * (psize_y_ - 1) + 1)) /
113  pstride_y_ + 1),
114  o_width_((e.shape_[dstdim - 1] - (pdilate_x_ * (psize_x_ - 1) + 1)) /
115  pstride_x_ + 1) {
116  // note: i/o conventions are the same as in unpack
117  }
118  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
119  using namespace std;
120  const index_t y = i % i_height_;
121  const index_t idivh = i / i_height_;
122  const index_t c = idivh % i_channel_;
123  const index_t n = idivh / i_channel_;
124  const index_t x = j;
125 
126  const index_t psize_y_dilate = (pdilate_y_ * (psize_y_ - 1) + 1);
127  const index_t psize_x_dilate = (pdilate_x_ * (psize_x_ - 1) + 1);
128 
129  const index_t py_min =
130  y < psize_y_dilate ? y % pdilate_y_ : (y-psize_y_dilate + pstride_y_) / pstride_y_;
131  const index_t px_min =
132  x < psize_x_dilate ? x % pdilate_x_ : (x-psize_x_dilate + pstride_x_) / pstride_x_;
133  const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_);
134  const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_);
135  DType res = static_cast<DType>(0);
136  for (index_t py = py_min; py < py_max; py += pdilate_y_) {
137  for (index_t px = px_min; px < px_max; px += pdilate_x_) {
138  res += src_.Eval(((c * psize_y_ + (y - py*pstride_y_) / pdilate_y_) * psize_x_ +
139  (x - px * pstride_x_) / pdilate_x_),
140  (n * o_height_ + py) * o_width_ + px);
141  }
142  }
143  return res;
144  }
145 
146  private:
// NOTE(review): original lines 147 and 149 are not shown in this listing.
// Presumably they declare src_ and the pdilate_y_ / pdilate_x_ members that
// Eval and the constructor use.
// Patch geometry and channel count, fixed at construction.
148  const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_;
// Input image height and output-grid size, fixed at construction.
150  const index_t i_height_, o_height_, o_width_;
151 };
152 } // namespace expr
153 } // namespace mshadow
154 #endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_
PackColToPatchXExp< SrcExp, DType, dstdim > pack_col2patch(const expr::Exp< SrcExp, DType, etype > &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
reverse operation of unpack_patch2col (UnpackPatchToCol); can be used to implement deconvolution
Definition: pack_col2patch.h:72
index_t pdilate_x_
Definition: pack_col2patch.h:36
Definition: expr_engine-inl.h:40
used to help static type check
Definition: expr_engine-inl.h:312
shape of a tensor
Definition: tensor.h:35
const SrcExp & src_
source operand
Definition: pack_col2patch.h:26
Definition: optional.h:241
index_t pstride_x_
Definition: pack_col2patch.h:33
index_t psize_y_
patch height
Definition: pack_col2patch.h:28
reverse operation of UnpackPatchToCol, used to backpropagate the gradient; this version supports multiple images (leading batch dimensions)
Definition: pack_col2patch.h:22
index_t pstride_y_
patch stride
Definition: pack_col2patch.h:32
index_t psize_x_
patch width
Definition: pack_col2patch.h:30
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:204
int32_t index_t
type that will be used for index
Definition: base.h:291
PackColToPatchXExp(const SrcExp &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: pack_col2patch.h:38
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:139
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: pack_col2patch.h:118
index_t pdilate_y_
patch dilation (height direction)
Definition: pack_col2patch.h:35
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:25
namespace for mshadow
Definition: base.h:282
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:29
Plan(const PackColToPatchXExp< SrcExp, DType, dstdim > &e)
Definition: pack_col2patch.h:107