# Source code for mxnet.gluon.data.dataset

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=
"""Dataset container."""
# Public API of this module; the underscore-prefixed classes are internal.
__all__ = [
    'Dataset',
    'SimpleDataset',
    'ArrayDataset',
    'RecordFileDataset',
]

import os

from ... import recordio, ndarray
from ...util import default_array


class Dataset(object):
    """Abstract dataset class. All datasets should have this interface.

    Subclasses need to override `__getitem__`, which returns the i-th
    element, and `__len__`, which returns the total number of elements.

    .. note:: An mxnet or numpy array can be directly used as a dataset.
    """
    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def filter(self, fn):
        """Returns a new dataset with samples filtered by the
        filter function `fn`.

        Note that if the Dataset is the result of a lazily transformed one
        with transform(lazy=True), the filter is eagerly applied to the
        transformed samples without materializing the transformed result.
        That is, the transformation will be applied again whenever a sample
        is retrieved after filter().

        Parameters
        ----------
        fn : callable
            A filter function that takes a sample as input and
            returns a boolean. Samples that return False are discarded.

        Returns
        -------
        Dataset
            The filtered dataset.
        """
        from . import FilterSampler
        return _SampledDataset(self, FilterSampler(fn, self))

    def shard(self, num_shards, index):
        """Returns a new dataset that includes only 1/num_shards of this dataset.

        For distributed training, be sure to shard before you randomize the
        dataset (such as shuffle), if you want each worker to reach a unique
        subset.

        Parameters
        ----------
        num_shards : int
            An integer representing the number of data shards.
        index : int
            An integer representing the index of the current shard.

        Returns
        -------
        Dataset
            The result dataset.

        Raises
        ------
        AssertionError
            If `num_shards` is not positive, or `index` is outside
            ``[0, num_shards)``.
        """
        # Validate num_shards first: with num_shards <= 0 the bound check on
        # `index` below would otherwise fire with a misleading message.
        assert num_shards > 0, 'Number of shards must be greater than 0'
        assert index >= 0, 'Index must be non-negative'
        assert index < num_shards, \
            f'Shard index out of bound: {index} out of {num_shards}'
        length = len(self)
        shard_len = length // num_shards
        rest = length % num_shards
        # The first `rest` shards each receive one extra sample, so shards
        # are contiguous and their sizes differ by at most one.
        start = shard_len * index + min(index, rest)
        end = start + shard_len + (index < rest)
        from . import SequentialSampler
        return _SampledDataset(self, SequentialSampler(end - start, start))

    def take(self, count):
        """Returns a new dataset with at most `count` number of samples in it.

        Parameters
        ----------
        count : int or None
            An integer representing the number of elements of this dataset
            that should be taken to form the new dataset. If count is None,
            or if count is greater than the size of this dataset, the new
            dataset will contain all elements of this dataset.

        Returns
        -------
        Dataset
            The result dataset.
        """
        if count is None or count > len(self):
            count = len(self)
        from . import SequentialSampler
        return _SampledDataset(self, SequentialSampler(count))

    def sample(self, sampler):
        """Returns a new dataset with elements sampled by the sampler.

        Parameters
        ----------
        sampler : Sampler
            A Sampler that returns the indices of sampled elements.

        Returns
        -------
        Dataset
            The result dataset.

        Raises
        ------
        TypeError
            If `sampler` is not a `gluon.data.Sampler`.
        """
        from . import Sampler
        if not isinstance(sampler, Sampler):
            raise TypeError(f'Invalid sampler type: {type(sampler)}. Expected gluon.data.Sampler instead.')
        return _SampledDataset(self, sampler)

    def transform(self, fn, lazy=True):
        """Returns a new dataset with each sample transformed by the
        transformer function `fn`.

        Parameters
        ----------
        fn : callable
            A transformer function that takes a sample as input and
            returns the transformed sample.
        lazy : bool, default True
            If False, transforms all samples at once. Otherwise,
            transforms each sample on demand. Note that if `fn`
            is stochastic, you must set lazy to True or you will
            get the same result on all epochs.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        trans = _LazyTransformDataset(self, fn)
        if lazy:
            return trans
        # Eager mode: run the lazy wrapper over every sample once and keep
        # the results in memory.
        return SimpleDataset([i for i in trans])

    def transform_first(self, fn, lazy=True):
        """Returns a new dataset with the first element of each sample
        transformed by the transformer function `fn`.

        This is mostly applicable when each sample contains two components
        - features and label, i.e., (X, y), and you only want to transform
        the first element X (i.e., the features) while keeping the label y
        unchanged.

        Parameters
        ----------
        fn : callable
            A transformer function that takes the first element of a sample
            as input and returns the transformed element.
        lazy : bool, default True
            If False, transforms all samples at once. Otherwise,
            transforms each sample on demand. Note that if `fn`
            is stochastic, you must set lazy to True or you will
            get the same result on all epochs.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        return self.transform(_TransformFirstClosure(fn), lazy)
class SimpleDataset(Dataset):
    """Thin Dataset adapter around an existing list- or array-like container.

    Parameters
    ----------
    data : dataset-like object
        Any object that implements `len()` and `[]`.
    """
    def __init__(self, data):
        self._data = data
        # Backend (C++) handle; created on the first __mx_handle__ call and reused.
        self._handle = None

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return self._data[idx]

    def __mx_handle__(self):
        """Return the cached backend handle, building it on first use.

        Raises
        ------
        NotImplementedError
            If the wrapped data is neither a numpy array nor an mxnet NDArray.
        """
        if self._handle is not None:
            return self._handle
        import numpy as np
        from ._internal import NDArrayDataset
        array_types = (np.ndarray, ndarray.NDArray)
        if not isinstance(self._data, array_types):
            raise NotImplementedError(
                "C++ handle for general type object is not supported, "
                "given {}, expect np.ndarray".format(type(self._data)))
        self._handle = NDArrayDataset(arr=default_array(self._data))
        return self._handle
class _LazyTransformDataset(Dataset):
    """Lazily transformed dataset.

    Applies `fn` to each sample of `data` when the sample is accessed.
    Tuple samples are unpacked into positional arguments of `fn`.
    """
    def __init__(self, data, fn):
        self._data = data
        self._fn = fn
        # Backend (C++) handle; built on the first __mx_handle__ call.
        self.handle = None

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        item = self._data[idx]
        if isinstance(item, tuple):
            return self._fn(*item)
        return self._fn(item)

    def __mx_handle__(self):
        """Build (once) and return a backend LazyTransformDataset handle.

        Only supported when the wrapped data exposes `__mx_handle__` and the
        transform is a HybridBlock, or a `_TransformFirstClosure` wrapping one.

        Raises
        ------
        NotImplementedError
            If the wrapped data or the transform has no backend equivalent.
        """
        if self.handle is None:
            from ..block import HybridBlock
            from ._internal import LazyTransformDataset
            from ...base import numeric_types
            if not hasattr(self._data, '__mx_handle__'):
                raise NotImplementedError("{} don't support backend".format(self._data))
            if isinstance(self._fn, HybridBlock):
                # Hybridize, then run one forward pass on the first sample;
                # presumably that pass is what populates `_cached_op`. The
                # outputs are also inspected to flag which ones are scalars.
                item = self._data[0]
                self._fn.hybridize()
                if isinstance(item, tuple):
                    ret = self._fn(*item)
                    # NOTE(review): assumes a tuple input yields an iterable
                    # of outputs — confirm for single-output blocks.
                    is_scalar = [int(isinstance(x, numeric_types)) for x in ret]
                else:
                    ret = self._fn(item)
                    is_scalar = [int(isinstance(ret, numeric_types))]
                cached_op = self._fn._cached_op
                self.handle = LazyTransformDataset(cached_op=cached_op,
                                                   dataset=self._data.__mx_handle__(),
                                                   scalar_outputs=tuple(is_scalar))
            elif isinstance(self._fn, _TransformFirstClosure):
                if not isinstance(self._fn._fn, HybridBlock):
                    raise NotImplementedError("Block not supported.")
                # Only the first element of each sample is transformed.
                item = self._data[0][0]
                self._fn._fn.hybridize()
                ret = self._fn._fn(item)
                is_scalar = [int(isinstance(ret, numeric_types))]
                cached_op = self._fn._fn._cached_op
                self.handle = LazyTransformDataset(cached_op=cached_op,
                                                   dataset=self._data.__mx_handle__(),
                                                   scalar_outputs=tuple(is_scalar),
                                                   transform_indices=(0,))
            else:
                raise NotImplementedError(
                    "C++ handle Not implemented for transforms that are not hybridizable")
        return self.handle


class _TransformFirstClosure(object):
    """Use callable object instead of nested function, it can be pickled.

    Applies the wrapped `fn` to the first positional argument and passes
    any remaining arguments through unchanged.
    """
    def __init__(self, fn):
        self._fn = fn

    def __call__(self, x, *args):
        if args:
            return (self._fn(x),) + args
        return self._fn(x)


def _make_indexed_handle(dataset, indices):
    """Return a backend IndexedDataset that selects `indices` from `dataset`.

    Raises
    ------
    NotImplementedError
        If `dataset` neither exposes `__mx_handle__` nor is an MXDataset.
    """
    from ._internal import MXDataset, IndexedDataset
    if hasattr(dataset, '__mx_handle__'):
        base = dataset.__mx_handle__()
    elif isinstance(dataset, MXDataset):
        base = dataset
    else:
        raise NotImplementedError('{} not supported.'.format(dataset))
    return IndexedDataset(base=base, indices=indices)


class _FilteredDataset(Dataset):
    """Dataset with a filter applied.

    `fn` is evaluated on every sample once, at construction. Only the
    indices of samples for which it returns a truthy value are kept.
    """
    def __init__(self, dataset, fn):
        self._dataset = dataset
        # Index explicitly over len(dataset) instead of relying on the legacy
        # __getitem__ iteration protocol, which only stops when the dataset
        # happens to raise IndexError past its end.
        self._indices = [i for i in range(len(dataset)) if fn(dataset[i])]
        self.handle = None

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]

    def __mx_handle__(self):
        if self.handle is None:
            self.handle = _make_indexed_handle(self._dataset, self._indices)
        return self.handle


class _SampledDataset(Dataset):
    """Dataset with elements chosen by a sampler.

    The sampler is iterated exactly once, at construction. The resulting
    indices are frozen, so length, item access and the backend handle all
    see the same selection.
    """
    def __init__(self, dataset, sampler):
        self._dataset = dataset
        self._sampler = sampler
        self._indices = list(iter(sampler))
        self.handle = None

    def __len__(self):
        # Report the number of materialized indices rather than
        # len(sampler): __getitem__ and the backend handle both use
        # _indices, and a sampler's __len__ need not match what it yielded.
        return len(self._indices)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]

    def __mx_handle__(self):
        if self.handle is None:
            self.handle = _make_indexed_handle(self._dataset, self._indices)
        return self.handle
class ArrayDataset(Dataset):
    """A dataset that combines multiple dataset-like objects, e.g.
    Datasets, lists, arrays, etc.

    The i-th sample is defined as `(x1[i], x2[i], ...)`.

    Parameters
    ----------
    *args : one or more dataset-like objects
        The data arrays.

    Raises
    ------
    AssertionError
        If no arrays are given or the arrays differ in length.
    """
    def __init__(self, *args):
        assert len(args) > 0, "Needs at least 1 arrays"
        self._length = len(args[0])
        self._data = []
        for i, data in enumerate(args):
            # `i` comes from enumerate over all args, so it is already the
            # position of `data` in args (no +1 offset).
            assert len(data) == self._length, \
                f"All arrays must have the same length; array[0] has length {self._length} " \
                f"while array[{i}] has {len(data)}."
            # 1-D NDArrays are stored as numpy arrays. Presumably this makes
            # indexing yield a numpy scalar instead of an NDArray — confirm.
            if isinstance(data, ndarray.NDArray) and len(data.shape) == 1:
                data = data.asnumpy()
            self._data.append(data)
        # Backend (C++) handle; built on the first __mx_handle__ call.
        self.handle = None

    def __getitem__(self, idx):
        # A single array yields bare elements rather than 1-tuples.
        if len(self._data) == 1:
            return self._data[0][idx]
        else:
            return tuple(data[idx] for data in self._data)

    def __len__(self):
        return self._length

    def __mx_handle__(self):
        """Build (once) and return a backend GroupDataset over all arrays."""
        if self.handle is None:
            from ._internal import MXDataset, NDArrayDataset, GroupDataset
            datasets = []
            for data in self._data:
                if isinstance(data, MXDataset):
                    datasets.append(data)
                elif hasattr(data, '__mx_handle__'):
                    datasets.append(data.__mx_handle__())
                else:
                    datasets.append(NDArrayDataset(arr=default_array(data)))
            self.handle = GroupDataset(datasets=datasets)
        return self.handle
class RecordFileDataset(Dataset):
    """Dataset that reads raw records from a RecordIO (.rec) file.

    Each sample is the raw content of one record. The index file is looked
    up next to `filename`: same path, with the extension replaced by `.idx`.

    Parameters
    ----------
    filename : str
        Path to rec file.
    """
    def __init__(self, filename):
        stem, _ext = os.path.splitext(filename)
        self.idx_file = stem + '.idx'
        self.filename = filename
        self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')

    def __getitem__(self, idx):
        # Map the positional index to the record key, then read that record.
        key = self._record.keys[idx]
        return self._record.read_idx(key)

    def __len__(self):
        return len(self._record.keys)

    def __mx_handle__(self):
        """Return a new backend RecordFileDataset over the same files."""
        from ._internal import RecordFileDataset as _RecordFileDataset
        return _RecordFileDataset(rec_file=self.filename, idx_file=self.idx_file)
class _DownloadedDataset(Dataset):
    """Base class for MNIST, cifar10, etc.

    Subclasses implement `_get_data`, which is expected to populate
    `self._data` and `self._label`. It is called once from `__init__`,
    after `root` has been created.

    Parameters
    ----------
    root : str
        Directory for the dataset files. `~` is expanded, and the directory
        is created if missing.
    transform : None
        Must be None. Passing a transform raises DeprecationWarning; use
        `dataset.transform()` or `dataset.transform_first()` instead.
    """
    def __init__(self, root, transform):
        super().__init__()
        if transform is not None:
            raise DeprecationWarning(
                'Directly apply transform to dataset is deprecated. '
                'Please use dataset.transform() or dataset.transform_first() instead...')
        self._transform = transform
        self._data = None
        self._label = None
        root = os.path.expanduser(root)
        self._root = root
        # exist_ok avoids the check-then-create race of isdir()+makedirs().
        # It still raises FileExistsError if `root` exists as a non-directory.
        os.makedirs(root, exist_ok=True)
        self._get_data()
        # Backend (C++) handle; built on the first __mx_handle__ call.
        self.handle = None

    def __getitem__(self, idx):
        if self._transform is not None:
            return self._transform(self._data[idx], self._label[idx])
        return self._data[idx], self._label[idx]

    def __len__(self):
        return len(self._label)

    def _get_data(self):
        raise NotImplementedError

    def __mx_handle__(self):
        """Build (once) and return a backend GroupDataset of (data, label)."""
        if self.handle is None:
            from ._internal import NDArrayDataset, GroupDataset
            self.handle = GroupDataset(
                datasets=(NDArrayDataset(arr=default_array(self._data)),
                          NDArrayDataset(arr=default_array(self._label))))
        return self.handle