mxnet
tensor_cpu-inl.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_TENSOR_CPU_INL_H_
26 #define MSHADOW_TENSOR_CPU_INL_H_
27 #include <cstring>
28 #include <functional>
29 #include <utility>
30 #include <vector>
31 #include "./base.h"
32 #include "./tensor.h"
33 #include "./packet-inl.h"
34 #include "./dot_engine-inl.h"
35 
36 namespace mshadow {
37 template<>
38 inline void InitTensorEngine<cpu>(int dev_id) {
39 }
40 template<>
41 inline void ShutdownTensorEngine<cpu>(void) {
42 }
43 
44 template<>
45 inline void SetDevice<cpu>(int devid) {
46 }
47 template<>
48 inline Stream<cpu> *NewStream<cpu>(bool create_blas_handle,
49  bool create_dnn_handle,
50  int dev_id) {
51  return new Stream<cpu>();
52 }
53 template<>
54 inline void DeleteStream<cpu>(Stream<cpu> *stream) {
55  delete stream;
56 }
57 
58 template<int ndim>
59 inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape) { // NOLINT(*)
60  os << '(';
61  for (int i = 0; i < ndim; ++i) {
62  if (i != 0) os << ',';
63  os << shape[i];
64  }
65  // python style tuple
66  if (ndim == 1) os << ',';
67  os << ')';
68  return os;
69 }
70 
71 template<typename xpu>
72 inline void *AllocHost_(size_t size);
73 template<typename xpu>
74 inline void FreeHost_(void * dptr);
75 
76 #ifdef __CUDACC__
77 template<>
78 inline void *AllocHost_<gpu>(size_t size) {
79  void *dptr;
80  MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable));
81  return dptr;
82 }
83 template<>
84 inline void FreeHost_<gpu>(void *dptr) {
85  MSHADOW_CUDA_CALL(cudaFreeHost(dptr));
86 }
87 #endif
88 
89 template<>
90 inline void *AllocHost_<cpu>(size_t size) {
91  size_t pitch;
92  return packet::AlignedMallocPitch(&pitch, size, 1);
93 }
94 template<>
95 inline void FreeHost_<cpu>(void *dptr) {
96  packet::AlignedFree(dptr);
97 }
98 
99 template<typename xpu, int dim, typename DType>
100 inline void AllocHost(Tensor<cpu, dim, DType> *obj) {
101  obj->stride_ = obj->size(dim - 1);
102  CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost";
103  void *dptr = AllocHost_<xpu>(obj->MSize() * sizeof(DType));
104  obj->dptr_ = reinterpret_cast<DType*>(dptr);
105 }
106 template<typename xpu, int dim, typename DType>
107 inline void FreeHost(Tensor<cpu, dim, DType> *obj) {
108  if (obj->dptr_ == NULL) {
109  LOG(FATAL) << "FreeHost:: double free";
110  }
111  FreeHost_<xpu>(obj->dptr_);
112  obj->dptr_ = NULL;
113 }
114 
115 template<int dim, typename DType>
116 inline void AllocSpace(Tensor<cpu, dim, DType> *obj, bool pad) {
117  size_t pitch;
118  void *dptr;
119  if (pad) {
120  dptr = packet::AlignedMallocPitch
121  (&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]);
122  obj->stride_ = static_cast<index_t>(pitch / sizeof(DType));
123  } else {
124  obj->stride_ = obj->size(dim - 1);
125  dptr = packet::AlignedMallocPitch
126  (&pitch, obj->shape_.Size() * sizeof(DType), 1);
127  }
128  obj->dptr_ = reinterpret_cast<DType*>(dptr);
129 }
130 template<typename Device, typename DType, int dim>
131 inline Tensor<Device, dim, DType>
132 NewTensor(const Shape<dim> &shape, DType initv, bool pad, Stream<Device> *stream_) {
133  Tensor<Device, dim, DType> obj(shape);
134  obj.stream_ = stream_;
135  AllocSpace(&obj, pad);
136  MapExp<sv::saveto>(&obj, expr::ScalarExp<DType>(initv));
137  return obj;
138 }
139 template<int dim, typename DType>
140 inline void FreeSpace(Tensor<cpu, dim, DType> *obj) {
141  packet::AlignedFree(obj->dptr_);
142  obj->dptr_ = NULL;
143 }
144 template<int dim, typename DType>
145 inline void Copy(Tensor<cpu, dim, DType> _dst,
146  const Tensor<cpu, dim, DType> &_src,
147  Stream<cpu> *stream) {
148  CHECK_EQ(_dst.shape_, _src.shape_)
149  << "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_;
150  if (_dst.CheckContiguous() && _src.CheckContiguous()) {
151  memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size());
152  } else {
153  Tensor<cpu, 2, DType> dst = _dst.FlatTo2D();
154  Tensor<cpu, 2, DType> src = _src.FlatTo2D();
155  for (index_t y = 0; y < dst.size(0); ++y) {
156  memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1));
157  }
158  }
159 }
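
A minimal usage sketch for the allocation and copy helpers above (NewTensor, Copy, FreeSpace). The shape, fill values, and the function name copy_example are illustrative assumptions, not part of this header:

  #include "mshadow/tensor.h"
  // allocate two 2x3 CPU tensors, copy one into the other, then release them
  void copy_example() {
    using namespace mshadow;
    Tensor<cpu, 2, float> a = NewTensor<cpu>(Shape2(2, 3), 1.0f);  // every element set to 1.0f
    Tensor<cpu, 2, float> b = NewTensor<cpu>(Shape2(2, 3), 0.0f);
    Copy(b, a);  // shapes must match; the contiguous fast path is a single memcpy
    FreeSpace(&a);
    FreeSpace(&b);
  }
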
160 
161 template<typename Saver, typename R, int dim,
162  typename DType, typename E>
163 inline void MapPlan(TRValue<R, cpu, dim, DType> *dst,
164  const expr::Plan<E, DType> &plan) {
165  Shape<2> shape = expr::ShapeCheck<dim, R>::Check(dst->self()).FlatTo2D();
166  expr::Plan<R, DType> dplan = expr::MakePlan(dst->self());
167 #ifndef __CUDACC__
168  #pragma omp parallel for
169 #endif
170  // temp remove openmp, as default setting throttles CPU
171  for (openmp_index_t y = 0; y < shape[0]; ++y) {
172  for (index_t x = 0; x < shape[1]; ++x) {
173  // trust your compiler! -_- they will optimize it
174  Saver::template Save<DType>(dplan.REval(y, x), plan.Eval(y, x));
175  }
176  }
177 }
178 // code to handle SSE optimization
179 template<bool pass_check, typename Saver,
180  typename R, int dim,
181  typename DType, typename E, int etype>
182 struct MapExpCPUEngine {
183  inline static void Map(TRValue<R, cpu, dim, DType> *dst,
184  const expr::Exp<E, DType, etype> &exp) {
185  MapPlan<Saver>(dst, MakePlan(exp.self()));
186  }
187 };
188 
189 template<typename SV, int dim, typename DType, typename E, int etype>
190 struct MapExpCPUEngine<true, SV, Tensor<cpu, dim, DType>,
191  dim, DType, E, etype> {
192  inline static void Map(Tensor<cpu, dim, DType> *dst,
193  const expr::Exp<E, DType, etype> &exp) {
194  if (expr::PacketAlignCheck<dim, E, MSHADOW_DEFAULT_PACKET>::Check(exp.self()) &&
195  expr::PacketAlignCheck<dim, Tensor<cpu, dim, DType>, MSHADOW_DEFAULT_PACKET>::Check(*dst)) {
196  expr::MapPacketPlan<SV>(dst->self(),
197  expr::MakePacketPlan<MSHADOW_DEFAULT_PACKET>(exp.self()));
198  } else {
199  MapPlan<SV>(dst, MakePlan(exp.self()));
200  }
201  }
202 };
203 
204 
205 template<typename Saver, typename R, int dim,
206  typename DType, typename E, int etype>
207 inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
208  const expr::Exp<E, DType, etype> &exp) {
209  expr::TypeCheckPass<expr::TypeCheck<cpu, dim, DType, E>::kMapPass>
210  ::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
211  Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
212  Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
213  CHECK(eshape[0] == 0 || eshape == dshape)
214  << "Assignment: Shape of Tensors are not consistent with target, "
215  << "eshape: " << eshape << " dshape:" << dshape;
216  MapExpCPUEngine<expr::PacketCheck<E, MSHADOW_DEFAULT_PACKET>::kPass,
217  Saver, R, dim, DType, E, etype>
218  ::Map(dst->ptrself(), exp);
219 }
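
MapExp is the assignment engine behind mshadow's expression templates: it type-checks the expression, verifies shapes, and dispatches to the packet (SSE) path or the plain plan path. A hedged sketch of calling it directly, mirroring the ScalarExp call used by NewTensor above (the function name fill_example is an illustrative assumption):

  #include "mshadow/tensor.h"
  // set every element of dst to 0.5f; sv::saveto stores, sv::plusto would accumulate
  void fill_example(mshadow::Tensor<mshadow::cpu, 2, float> dst) {
    using namespace mshadow;
    MapExp<sv::saveto>(&dst, expr::ScalarExp<float>(0.5f));
  }
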
220 
221 template<typename Saver, typename Reducer,
222  typename R, typename DType, typename E, int etype>
223 inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
224  const expr::Exp<E, DType, etype> &exp,
225  DType scale) {
226  expr::TypeCheckPass<expr::TypeCheck<cpu, 1, DType, E>::kRedPass>
227  ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
228  Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
229  ::Check(exp.self()).FlatTo2D();
230  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
231  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match";
232  CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor";
233  // execution
234  expr::Plan<R, DType> dplan = MakePlan(dst->self());
235  expr::Plan<E, DType> splan = MakePlan(exp.self());
236 #ifndef __CUDACC__
237  #pragma omp parallel for
238 #endif
239  for (openmp_index_t x = 0; x < eshape[1]; ++x) {
240  DType res = splan.Eval(0, x);
241  for (index_t y = 1; y < eshape[0]; ++y) {
242  Reducer::Reduce(res, splan.Eval(y, x));
243  }
244  Saver::template Save<DType>(dplan.REval(0, x), res * scale);
245  }
246 }
247 
248 template<typename Saver, typename Reducer, int dimkeep,
249  typename R, typename DType, typename E, int etype>
250 inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
251  const expr::Exp<E, DType, etype> &exp,
252  DType scale) {
253  expr::TypeCheckPass<expr::TypeCheck<cpu, dimkeep, DType, E>::kRedPass>
254  ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
255  typedef Shape<expr::ExpInfo<E>::kDim> EShape;
256  EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
257  ::Check(exp.self());
258  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
259  CHECK_EQ(eshape[dimkeep], dshape[0])
260  << "MapReduceKeepHighDim::reduction dimension do not match";
261  // use equivalent form
262  Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
263  eshape[dimkeep],
264  eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
265  eshape[EShape::kSubdim]);
266  // execution
267  expr::Plan<R, DType> dplan = MakePlan(dst->self());
268  expr::Plan<E, DType> splan = MakePlan(exp.self());
269 #ifndef __CUDACC__
270  #pragma omp parallel for
271 #endif
272  for (openmp_index_t c = 0; c < pshape[1]; ++c) {
273  DType res; Reducer::SetInitValue(res);
274  for (index_t n = 0; n < pshape[0]; ++n) {
275  DType tres; Reducer::SetInitValue(tres);
276  for (index_t y = 0; y < pshape[2]; ++y) {
277  for (index_t x = 0; x < pshape[3]; ++x) {
278  Reducer::Reduce(tres,
279  splan.Eval((n * pshape[1] + c) * pshape[2] + y, x));
280  }
281  }
282  Reducer::Reduce(res, tres);
283  }
284  Saver::template Save<DType>(dplan.REval(0, c), DType(res * scale));
285  }
286 }
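
Both reducers flatten the expression and fold it with a Reducer while a Saver writes the result. A hedged sketch using MapReduceKeepLowest to sum a 2-D tensor over its leading dimension (sum_example is an illustrative name; red::sum and sv::saveto come from base.h):

  #include "mshadow/tensor.h"
  // dst[x] = sum over y of src[y][x]; requires dst.size(0) == src.size(1)
  void sum_example(mshadow::Tensor<mshadow::cpu, 1, float> dst,
                   const mshadow::Tensor<mshadow::cpu, 2, float> &src) {
    using namespace mshadow;
    MapReduceKeepLowest<sv::saveto, red::sum>(&dst, src, 1.0f);
  }
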
287 
288 template<typename DType>
289 inline void Softmax(Tensor<cpu, 1, DType> dst,
290  const Tensor<cpu, 1, DType> &energy) {
291  DType mmax = energy[0];
292  for (index_t x = 1; x < dst.size(0); ++x) {
293  if (mmax < energy[x]) mmax = energy[x];
294  }
295  DType sum = DType(0.0f);
296  for (index_t x = 0; x < dst.size(0); ++x) {
297  dst[x] = std::exp(energy[x] - mmax);
298  sum += dst[x];
299  }
300  for (index_t x = 0; x < dst.size(0); ++x) {
301  dst[x] /= sum;
302  }
303 }
304 
305 template<typename DType>
306 inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
307  const Tensor<cpu, 2, DType> &src,
308  const Tensor<cpu, 1, DType> &label) {
309 #pragma omp parallel for
310  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
311  const index_t k = static_cast<int>(label[y]);
312  for (index_t x = 0; x < dst.size(1); ++x) {
313  if (x == k) {
314  dst[y][k] = src[y][k] - 1.0f;
315  } else {
316  dst[y][x] = src[y][x];
317  }
318  }
319  }
320 }
321 
322 template<typename DType>
323 inline void SmoothSoftmaxGrad(Tensor<cpu, 2, DType> dst,
324  const Tensor<cpu, 2, DType> &src,
325  const Tensor<cpu, 1, DType> &label,
326  const float alpha) {
327  const float smooth_grad = (alpha / (dst.size(1) - 1));
328 #pragma omp parallel for
329  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
330  const index_t k = static_cast<int>(label[y]);
331  for (index_t x = 0; x < dst.size(1); ++x) {
332  if (x == k) {
333  dst[y][k] = src[y][k] - 1.0f + alpha;
334  } else {
335  dst[y][x] = src[y][x] - smooth_grad;
336  }
337  }
338  }
339 }
340 
341 
342 template<typename DType>
343 inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
344  const Tensor<cpu, 2, DType> &src,
345  const Tensor<cpu, 1, DType> &label,
346  const DType &ignore_label) {
347 #pragma omp parallel for
348  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
349  const int k = static_cast<int>(label[y]);
350  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
351  if (static_cast<int>(ignore_label) == k) {
352  dst[y][x] = 0.0f;
353  } else {
354  if (x == k) {
355  dst[y][k] = src[y][k] - 1.0f;
356  } else {
357  dst[y][x] = src[y][x];
358  }
359  }
360  }
361  }
362 }
363 
364 template<typename DType>
365 inline void SmoothSoftmaxGrad(Tensor<cpu, 2, DType> dst,
366  const Tensor<cpu, 2, DType> &src,
367  const Tensor<cpu, 1, DType> &label,
368  const DType &ignore_label,
369  const float alpha) {
370  const float smooth_grad = (alpha / (dst.size(1) - 1));
371 #pragma omp parallel for
372  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
373  const int k = static_cast<int>(label[y]);
374  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
375  if (static_cast<int>(ignore_label) == k) {
376  dst[y][x] = 0.0f;
377  } else {
378  if (x == k) {
379  dst[y][k] = src[y][k] - 1.0f + alpha;
380  } else {
381  dst[y][x] = src[y][x] - smooth_grad;
382  }
383  }
384  }
385  }
386 }
387 
388 template<typename DType>
389 inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
390  const Tensor<cpu, 3, DType> &src,
391  const Tensor<cpu, 2, DType> &label) {
392 #pragma omp parallel for
393  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
394  for (index_t y = 0; y < dst.size(0); ++y) {
395  const int k = static_cast<int>(label[y][n]);
396  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
397  if (x == k) {
398  dst[y][k][n] = src[y][k][n] - 1.0f;
399  } else {
400  dst[y][x][n] = src[y][x][n];
401  }
402  }
403  }
404  }
405 }
406 
407 template<typename DType>
408 inline void SmoothSoftmaxGrad(Tensor<cpu, 3, DType> dst,
409  const Tensor<cpu, 3, DType> &src,
410  const Tensor<cpu, 2, DType> &label,
411  const float alpha) {
412  const float smooth_grad = (alpha / (dst.size(1) - 1));
413 #pragma omp parallel for
414  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
415  for (index_t y = 0; y < dst.size(0); ++y) {
416  const int k = static_cast<int>(label[y][n]);
417  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
418  if (x == k) {
419  dst[y][k][n] = src[y][k][n] - 1.0f + alpha;
420  } else {
421  dst[y][x][n] = src[y][x][n] - smooth_grad;
422  }
423  }
424  }
425  }
426 }
427 
428 template<typename DType>
429 inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
430  const Tensor<cpu, 3, DType> &src,
431  const Tensor<cpu, 2, DType> &label,
432  const DType &ignore_label) {
433 #pragma omp parallel for
434  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
435  for (index_t y = 0; y < dst.size(0); ++y) {
436  const int k = static_cast<int>(label[y][n]);
437  if (k == static_cast<int>(ignore_label)) {
438  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
439  dst[y][x][n] = DType(0.0f);
440  }
441  } else {
442  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
443  if (x == k) {
444  dst[y][k][n] = src[y][k][n] - 1.0f;
445  } else {
446  dst[y][x][n] = src[y][x][n];
447  }
448  }
449  }
450  }
451  }
452 }
453 
454 template<typename DType>
455 inline void SmoothSoftmaxGrad(Tensor<cpu, 3, DType> dst,
456  const Tensor<cpu, 3, DType> &src,
457  const Tensor<cpu, 2, DType> &label,
458  const DType &ignore_label,
459  const float alpha) {
460  const float smooth_grad = (alpha / (dst.size(1) - 1));
461 #pragma omp parallel for
462  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
463  for (index_t y = 0; y < dst.size(0); ++y) {
464  const int k = static_cast<int>(label[y][n]);
465  if (k == static_cast<int>(ignore_label)) {
466  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
467  dst[y][x][n] = DType(0.0f);
468  }
469  } else {
470  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
471  if (x == k) {
472  dst[y][k][n] = src[y][k][n] - 1.0f + alpha;
473  } else {
474  dst[y][x][n] = src[y][x][n] - smooth_grad;
475  }
476  }
477  }
478  }
479  }
480 }
481 
482 template<typename DType>
483 inline void Softmax(Tensor<cpu, 2, DType> dst,
484  const Tensor<cpu, 2, DType> &energy) {
485  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
486 #pragma omp parallel for
487  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
488  Softmax(dst[y], energy[y]);
489  }
490 }
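
A hedged usage sketch for the row-wise softmax above; the batch size and class count are illustrative assumptions:

  #include "mshadow/tensor.h"
  // prob[i][j] = exp(energy[i][j]) / sum_j exp(energy[i][j]) for a (4, 10) energy matrix
  void softmax_example() {
    using namespace mshadow;
    Tensor<cpu, 2, float> energy = NewTensor<cpu>(Shape2(4, 10), 0.0f);
    Tensor<cpu, 2, float> prob = NewTensor<cpu>(Shape2(4, 10), 0.0f);
    Softmax(prob, energy);
    FreeSpace(&energy);
    FreeSpace(&prob);
  }
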
491 
492 template<typename DType>
493 inline void Softmax(Tensor<cpu, 3, DType> dst,
494  const Tensor<cpu, 3, DType> &energy) {
495  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
496 #pragma omp parallel for
497  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
498  for (index_t n = 0; n < dst.size(2); ++n) {
499  DType mmax = energy[y][0][n];
500  for (index_t x = 1; x < dst.size(1); ++x) {
501  if (mmax < energy[y][x][n]) mmax = energy[y][x][n];
502  }
503  DType sum = DType(0.0f);
504  for (index_t x = 0; x < dst.size(1); ++x) {
505  dst[y][x][n] = std::exp(energy[y][x][n] - mmax);
506  sum += dst[y][x][n];
507  }
508  for (index_t x = 0; x < dst.size(1); ++x) {
509  dst[y][x][n] /= sum;
510  }
511  }
512  }
513 }
514 
515 template<bool clip, typename IndexType, typename DType>
516 inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
517  const Tensor<cpu, 1, IndexType>& index,
518  const Tensor<cpu, 2, DType> &src) {
519  const index_t K = dst.shape_[0];
520  const index_t C = dst.shape_[1];
521  for (index_t y = 0; y < index.size(0); ++y) {
522  index_t j = index[y];
523  if (clip) {
524  if (j <= 0) j = 0;
525  else if (j >= K) j = K - 1;
526  } else {
527  j %= K;
528  if (j < 0) j += K;
529  }
530  for (index_t i = 0; i < C; ++i) {
531  dst[j][i] += src[y][i];
532  }
533  }
534 }
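
AddTakeGrad accumulates embedding gradients row by row: dst[index[i]] += src[i]. A hedged sketch with an assumed 100 x 16 embedding table (take_grad_example and the sizes are illustrative assumptions):

  #include "mshadow/tensor.h"
  // index holds row ids in [0, 100); grad_out has shape (index.size(0), 16)
  void take_grad_example(const mshadow::Tensor<mshadow::cpu, 1, int> &index,
                         const mshadow::Tensor<mshadow::cpu, 2, float> &grad_out) {
    using namespace mshadow;
    Tensor<cpu, 2, float> grad_weight = NewTensor<cpu>(Shape2(100, 16), 0.0f);
    AddTakeGrad<true>(grad_weight, index, grad_out);  // clip = true clamps out-of-range row ids
    FreeSpace(&grad_weight);
  }
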
535 
536 template<typename IndexType, typename DType>
537 inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
538  const Tensor<cpu, 1, IndexType>& sorted,
539  const Tensor<cpu, 1, IndexType>& index,
540  const Tensor<cpu, 2, DType> &src) {
541  for (index_t y = 0; y < sorted.size(0); ++y) {
542  dst[sorted[y]] += src[index[y]];
543  }
544 }
545 
546 template<typename IndexType, typename DType>
547 inline void IndexFill(Tensor<cpu, 2, DType> dst,
548  const Tensor<cpu, 1, IndexType>& index,
549  const Tensor<cpu, 2, DType> &src) {
550  for (index_t y = 0; y < index.size(0); ++y) {
551  for (index_t j = 0; j < src.size(1); j++) {
552  dst[index[y]][j] = src[y][j];
553  }
554  }
555 }
556 
557 template<typename KDType, typename VDType>
558 inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
559  bool is_ascend) {
560  CHECK_EQ(keys.CheckContiguous(), true);
561  CHECK_EQ(values.CheckContiguous(), true);
562  CHECK_EQ(keys.size(0), values.size(0))
563  << "The sizes of key/value are not equal! keys_size: " << keys.size(0)
564  << "values_size: " << values.size(0);
565  std::vector<size_t> idx(keys.size(0));
566  std::vector<KDType> keys_vec(keys.size(0));
567  std::vector<VDType> values_vec(values.size(0));
568  for (int i = 0; i < keys.size(0); i++) {
569  idx[i] = i;
570  keys_vec[i] = keys[i];
571  values_vec[i] = values[i];
572  }
573  if (is_ascend) {
574  std::stable_sort(idx.begin(), idx.end(),
575  [&keys_vec](size_t i1, size_t i2)
576  {return keys_vec[i1] < keys_vec[i2]; });
577  } else {
578  std::stable_sort(idx.begin(), idx.end(),
579  [&keys_vec](size_t i1, size_t i2)
580  {return keys_vec[i1] > keys_vec[i2]; });
581  }
582  for (index_t i = 0; i < values.size(0); i++) {
583  keys[i] = keys_vec[idx[i]];
584  values[i] = values_vec[idx[i]];
585  }
586 }
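
SortByKey is a stable key-value sort that stages both tensors in std::vectors and sorts an index array. A hedged sketch (the tensor contents are assumed to be filled by the caller):

  #include "mshadow/tensor.h"
  // reorder values so that keys end up in ascending order; both tensors must be contiguous
  void sort_example(mshadow::Tensor<mshadow::cpu, 1, float> keys,
                    mshadow::Tensor<mshadow::cpu, 1, int> values) {
    mshadow::SortByKey(keys, values, true);
  }
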
587 
588 template<typename Device, typename VDType, typename SDType>
589 inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments) {
590  // We can sort each segment using two stable sorts
591  SortByKey(values, segments, true);
592  SortByKey(segments, values, true);
593 }
594 
595 // blas related
596 template<typename Device, typename DType>
597 inline void VectorDot(Tensor<Device, 1, DType> dst,
598  const Tensor<Device, 1, DType> &lhs,
599  const Tensor<Device, 1, DType> &rhs) {
600  CHECK_EQ(lhs.size(0), rhs.size(0))
601  << "VectorDot: Shape mismatch";
602  CHECK_EQ(dst.size(0), 1U)
603  << "VectorDot: expect dst to be scalar";
604  mshadow::expr::BLASEngine<Device, DType>::SetStream(lhs.stream_);
605  mshadow::expr::BLASEngine<Device, DType>::dot(
606  lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_);
607 }
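
VectorDot routes a 1-D dot product through the BLAS engine and writes the scalar into a length-1 tensor. A hedged sketch, assuming a BLAS backend (e.g. MSHADOW_USE_CBLAS) is enabled at build time:

  #include "mshadow/tensor.h"
  // result[0] = sum_i lhs[i] * rhs[i]; lhs and rhs must have equal length
  void dot_example(const mshadow::Tensor<mshadow::cpu, 1, float> &lhs,
                   const mshadow::Tensor<mshadow::cpu, 1, float> &rhs) {
    using namespace mshadow;
    Tensor<cpu, 1, float> result = NewTensor<cpu>(Shape1(1), 0.0f);
    VectorDot(result, lhs, rhs);
    FreeSpace(&result);
  }
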
608 
609 template<bool transpose_left, bool transpose_right, typename Device, typename DType>
610 inline void BatchGEMM(Tensor<Device, 3, DType> dst,
611  const Tensor<Device, 3, DType> &lhs,
612  const Tensor<Device, 3, DType> &rhs,
613  DType alpha,
614  DType beta,
615  Tensor<Device, 1, DType*> workspace) {
616  index_t batch_size = dst.shape_[0];
617  expr::BLASEngine<Device, DType>::SetStream(dst.stream_);
618  Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1])
619  : lhs.shape_;
620  Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1])
621  : rhs.shape_;
622  CHECK_EQ(dst.CheckContiguous(), true);
623  CHECK_EQ(lhs.CheckContiguous(), true);
624  CHECK_EQ(rhs.CheckContiguous(), true);
625  CHECK(sleft[0] == batch_size && sright[0] == batch_size)
626  << "BatchGEMM: batchsize must be equal."
627  << "dst: " << dst.shape_ << "\n"
628  << "lhs: " << sleft << "\n"
629  << "rhs: " << sright << "\n";
630  CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
631  << "BatchGEMM: matrix shape mismatch"
632  << "dst: " << dst.shape_ << "\n"
633  << "lhs: " << sleft << "\n"
634  << "rhs: " << sright << "\n";
635  CHECK(workspace.size(0) >= 3 * batch_size)
636  << "Workspace Size must be bigger than " << 3 * batch_size;
637  CHECK_EQ(workspace.CheckContiguous(), true);
638  // use column major argument to be compatible with most BLAS
639  expr::BLASEngine<Device, DType>::batched_gemm
640  (dst.stream_,
641  transpose_right, transpose_left,
642  transpose_right ? rhs.size(1) : rhs.size(2),
643  transpose_left ? lhs.size(2) : lhs.size(1),
644  transpose_right ? rhs.size(2) : rhs.size(1),
645  alpha,
646  rhs.dptr_, rhs.stride_,
647  lhs.dptr_, lhs.stride_,
648  beta,
649  dst.dptr_, dst.stride_, batch_size,
650  workspace.dptr_);
651 }
652 } // namespace mshadow
653 #endif // MSHADOW_TENSOR_CPU_INL_H_
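
A hedged usage sketch for BatchGEMM above: it computes dst[b] = alpha * op(lhs[b]) * op(rhs[b]) + beta * dst[b] and needs a workspace of at least 3 * batch pointer slots for the batched BLAS call. The shapes, the batch size of 8, and the function name are illustrative assumptions, and a BLAS backend is assumed to be enabled:

  #include "mshadow/tensor.h"
  // dst: (8, 4, 6), lhs: (8, 4, 5), rhs: (8, 5, 6); all three must be contiguous
  void batch_gemm_example(mshadow::Tensor<mshadow::cpu, 3, float> dst,
                          const mshadow::Tensor<mshadow::cpu, 3, float> &lhs,
                          const mshadow::Tensor<mshadow::cpu, 3, float> &rhs) {
    using namespace mshadow;
    Tensor<cpu, 1, float*> workspace(Shape1(3 * 8));  // per-matrix pointers for batched_gemm
    AllocSpace(&workspace, false);
    BatchGEMM<false, false>(dst, lhs, rhs, 1.0f, 0.0f, workspace);
    FreeSpace(&workspace);
  }
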