# Source code for mxnet.gluon.data.dataset
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=
"""Dataset container."""
# Public API of this module; the underscore-prefixed classes defined below
# (_LazyTransformDataset, _SampledDataset, ...) are internal helpers.
__all__ = ['Dataset', 'SimpleDataset', 'ArrayDataset',
'RecordFileDataset']
import os
from ... import recordio, ndarray
from ...util import default_array
class Dataset(object):
    """Abstract dataset class. All datasets should have this interface.

    Subclasses need to override `__getitem__`, which returns the i-th
    element, and `__len__`, which returns the total number elements.

    .. note:: An mxnet or numpy array can be directly used as a dataset.
    """
    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def filter(self, fn):
        """Returns a new dataset with samples filtered by the
        filter function `fn`.

        Note that if the Dataset is the result of a lazily transformed one with
        transform(lazy=True), the filter is eagerly applied to the transformed
        samples without materializing the transformed result. That is, the
        transformation will be applied again whenever a sample is retrieved after
        filter().

        Parameters
        ----------
        fn : callable
            A filter function that takes a sample as input and
            returns a boolean. Samples that return False are discarded.

        Returns
        -------
        Dataset
            The filtered dataset.
        """
        from . import FilterSampler
        return _SampledDataset(self, FilterSampler(fn, self))

    def shard(self, num_shards, index):
        """Returns a new dataset that includes only 1/num_shards of this dataset.

        For distributed training, be sure to shard before you randomize the dataset
        (such as shuffle), if you want each worker to reach a unique subset.

        Parameters
        ----------
        num_shards : int
            An integer representing the number of data shards.
        index : int
            An integer representing the index of the current shard.

        Returns
        -------
        Dataset
            The result dataset.

        Raises
        ------
        AssertionError
            If `num_shards` is not positive or `index` is not in
            ``[0, num_shards)``.
        """
        # Validate num_shards first so a zero/negative shard count gets its
        # own message instead of a misleading "index out of bound" one.
        assert num_shards > 0, 'Number of shards must be greater than 0'
        assert index >= 0, 'Index must be non-negative'
        assert index < num_shards, f'Shard index out of bound: {index} out of {num_shards}'
        length = len(self)
        shard_len = length // num_shards
        rest = length % num_shards
        # The first `rest` shards each take one extra sample, so shard sizes
        # differ by at most one and the shards tile [0, length) contiguously.
        start = shard_len * index + min(index, rest)
        end = start + shard_len + (index < rest)
        from . import SequentialSampler
        return _SampledDataset(self, SequentialSampler(end - start, start))

    def take(self, count):
        """Returns a new dataset with at most `count` number of samples in it.

        Parameters
        ----------
        count : int or None
            An integer representing the number of elements of this dataset that
            should be taken to form the new dataset. If count is None, or if count
            is greater than the size of this dataset, the new dataset will contain
            all elements of this dataset.

        Returns
        -------
        Dataset
            The result dataset.
        """
        if count is None or count > len(self):
            count = len(self)
        from . import SequentialSampler
        return _SampledDataset(self, SequentialSampler(count))

    def sample(self, sampler):
        """Returns a new dataset with elements sampled by the sampler.

        Parameters
        ----------
        sampler : Sampler
            A Sampler that returns the indices of sampled elements.

        Returns
        -------
        Dataset
            The result dataset.

        Raises
        ------
        TypeError
            If `sampler` is not a `gluon.data.Sampler`.
        """
        from . import Sampler
        if not isinstance(sampler, Sampler):
            raise TypeError(f'Invalid sampler type: {type(sampler)}. Expected gluon.data.Sampler instead.')
        return _SampledDataset(self, sampler)

    def transform(self, fn, lazy=True):
        """Returns a new dataset with each sample transformed by the
        transformer function `fn`.

        Parameters
        ----------
        fn : callable
            A transformer function that takes a sample as input and
            returns the transformed sample.
        lazy : bool, default True
            If False, transforms all samples at once. Otherwise,
            transforms each sample on demand. Note that if `fn`
            is stochastic, you must set lazy to True or you will
            get the same result on all epochs.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        trans = _LazyTransformDataset(self, fn)
        if lazy:
            return trans
        # Index explicitly up to len(trans): Dataset defines no __iter__, so
        # plain iteration would fall back to the legacy __getitem__ protocol,
        # which only terminates if the underlying data raises IndexError.
        return SimpleDataset([trans[i] for i in range(len(trans))])

    def transform_first(self, fn, lazy=True):
        """Returns a new dataset with the first element of each sample
        transformed by the transformer function `fn`.

        This is mostly applicable when each sample contains two components
        - features and label, i.e., (X, y), and you only want to transform
        the first element X (i.e., the features) while keeping the label y
        unchanged.

        Parameters
        ----------
        fn : callable
            A transformer function that takes the first element of a sample
            as input and returns the transformed element.
        lazy : bool, default True
            If False, transforms all samples at once. Otherwise,
            transforms each sample on demand. Note that if `fn`
            is stochastic, you must set lazy to True or you will
            get the same result on all epochs.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        return self.transform(_TransformFirstClosure(fn), lazy)
class SimpleDataset(Dataset):
    """Expose any list-like or array-like container as a :class:`Dataset`.

    Parameters
    ----------
    data : dataset-like object
        Any object that implements `len()` and `[]`.
    """
    def __init__(self, data):
        self._data = data
        # Backend handle, created on first call to __mx_handle__ and reused.
        self._handle = None

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return self._data[idx]

    def __mx_handle__(self):
        # Fast path: the handle was already built on a previous call.
        if self._handle is not None:
            return self._handle
        import numpy as np
        from ._internal import NDArrayDataset
        supported = (np.ndarray, ndarray.NDArray)
        if not isinstance(self._data, supported):
            raise NotImplementedError(
                "C++ handle for general type object is not supported, "
                "given {}, expect np.ndarray".format(type(self._data)))
        self._handle = NDArrayDataset(arr=default_array(self._data))
        return self._handle
class _LazyTransformDataset(Dataset):
    """Lazily transformed dataset.

    Wraps `data` and applies `fn` to a sample each time it is accessed.
    Tuple samples are unpacked into positional arguments of `fn`.
    """
    def __init__(self, data, fn):
        self._data = data
        self._fn = fn
        self.handle = None

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        item = self._data[idx]
        if isinstance(item, tuple):
            return self._fn(*item)
        return self._fn(item)

    @staticmethod
    def _scalar_flags(ret, numeric_types):
        """Return a tuple with one 0/1 flag per output of `ret`, 1 meaning the
        output is a numeric scalar.

        A tuple/list `ret` is treated as multiple outputs; anything else
        (including a single array) is treated as exactly one output, so an
        array is never iterated row-by-row and a scalar is never iterated.
        """
        outputs = ret if isinstance(ret, (tuple, list)) else (ret,)
        return tuple(int(isinstance(x, numeric_types)) for x in outputs)

    def __mx_handle__(self):
        if self.handle is None:
            from ..block import HybridBlock
            from ._internal import LazyTransformDataset
            from ...base import numeric_types
            if not hasattr(self._data, '__mx_handle__'):
                raise NotImplementedError("{} don't support backend".format(self._data))
            if isinstance(self._fn, HybridBlock):
                # Run the hybridized block once on the first sample; the code
                # below reads `_cached_op` afterwards and uses the outputs to
                # learn which of them are scalars.
                item = self._data[0]
                self._fn.hybridize()
                if isinstance(item, tuple):
                    ret = self._fn(*item)
                else:
                    ret = self._fn(item)
                is_scalar = self._scalar_flags(ret, numeric_types)
                cached_op = self._fn._cached_op
                self.handle = LazyTransformDataset(cached_op=cached_op,
                                                   dataset=self._data.__mx_handle__(),
                                                   scalar_outputs=is_scalar)
            elif isinstance(self._fn, _TransformFirstClosure):
                if not isinstance(self._fn._fn, HybridBlock):
                    raise NotImplementedError("Block not supported.")
                # Only the first element of each sample is transformed, so probe
                # the wrapped block with the first element of the first sample.
                item = self._data[0][0]
                self._fn._fn.hybridize()
                ret = self._fn._fn(item)
                is_scalar = [int(isinstance(ret, numeric_types))]
                cached_op = self._fn._fn._cached_op
                self.handle = LazyTransformDataset(cached_op=cached_op,
                                                   dataset=self._data.__mx_handle__(),
                                                   scalar_outputs=tuple(is_scalar),
                                                   transform_indices=(0,))
            else:
                raise NotImplementedError(
                    "C++ handle Not implemented for transforms that are not hybridizable")
        return self.handle
class _TransformFirstClosure(object):
"""Use callable object instead of nested function, it can be pickled."""
def __init__(self, fn):
self._fn = fn
def __call__(self, x, *args):
if args:
return (self._fn(x),) + args
return self._fn(x)
class _FilteredDataset(Dataset):
    """Dataset with a filter applied.

    `fn` is evaluated once per sample at construction time; only the indices
    of kept samples are stored, and samples are re-read from the wrapped
    dataset on access.
    """
    def __init__(self, dataset, fn):
        self._dataset = dataset
        # Walk indices up to len(dataset) explicitly rather than iterating the
        # dataset: Dataset defines no __iter__, so plain iteration would rely on
        # the legacy __getitem__ protocol, which stops only on IndexError.
        self._indices = [i for i in range(len(dataset)) if fn(dataset[i])]
        self.handle = None

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]

    def __mx_handle__(self):
        if self.handle is None:
            from ._internal import MXDataset, IndexedDataset
            if hasattr(self._dataset, '__mx_handle__'):
                dataset = self._dataset.__mx_handle__()
            elif isinstance(self._dataset, MXDataset):
                dataset = self._dataset
            else:
                raise NotImplementedError('{} not supported.'.format(self._dataset))
            self.handle = IndexedDataset(base=dataset,
                                         indices=self._indices)
        return self.handle
class _SampledDataset(Dataset):
    """Dataset with elements chosen by a sampler.

    The sampler is iterated once at construction time, and the resulting
    indices into `dataset` are stored and used for every access.
    """
    def __init__(self, dataset, sampler):
        self._dataset = dataset
        self._sampler = sampler
        self._indices = list(iter(sampler))
        self.handle = None

    def __len__(self):
        # Report the number of materialized indices so that __len__ always
        # agrees with the index range __getitem__ accepts, even for samplers
        # whose __len__ is missing or differs from what they actually yield.
        return len(self._indices)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]

    def __mx_handle__(self):
        if self.handle is None:
            from ._internal import MXDataset, IndexedDataset
            if hasattr(self._dataset, '__mx_handle__'):
                dataset = self._dataset.__mx_handle__()
            elif isinstance(self._dataset, MXDataset):
                dataset = self._dataset
            else:
                raise NotImplementedError('{} not supported.'.format(self._dataset))
            self.handle = IndexedDataset(base=dataset,
                                         indices=self._indices)
        return self.handle
class ArrayDataset(Dataset):
    """A dataset that combines multiple dataset-like objects, e.g.
    Datasets, lists, arrays, etc.

    The i-th sample is defined as `(x1[i], x2[i], ...)`. With a single input,
    the i-th sample is just `x1[i]` (not wrapped in a tuple).

    Parameters
    ----------
    *args : one or more dataset-like objects
        The data arrays.

    Raises
    ------
    AssertionError
        If no arrays are given or the arrays differ in length.
    """
    def __init__(self, *args):
        assert len(args) > 0, "Needs at least 1 arrays"
        self._length = len(args[0])
        self._data = []
        for i, data in enumerate(args):
            # `i` is already the position of `data` in `args`.
            assert len(data) == self._length, \
                f"All arrays must have the same length; array[0] has length {self._length} " \
                f"while array[{i}] has {len(data)}."
            # 1-D NDArrays are converted to numpy, presumably so that indexing
            # yields plain scalars rather than 0-d NDArrays — TODO confirm.
            if isinstance(data, ndarray.NDArray) and len(data.shape) == 1:
                data = data.asnumpy()
            self._data.append(data)
        self.handle = None

    def __getitem__(self, idx):
        if len(self._data) == 1:
            return self._data[0][idx]
        else:
            return tuple(data[idx] for data in self._data)

    def __len__(self):
        return self._length

    def __mx_handle__(self):
        if self.handle is None:
            from ._internal import MXDataset, NDArrayDataset, GroupDataset
            datasets = []
            for data in self._data:
                if isinstance(data, MXDataset):
                    datasets.append(data)
                elif hasattr(data, '__mx_handle__'):
                    datasets.append(data.__mx_handle__())
                else:
                    datasets.append(NDArrayDataset(arr=default_array(data)))
            self.handle = GroupDataset(datasets=datasets)
        return self.handle
class RecordFileDataset(Dataset):
    """A dataset backed by a RecordIO (.rec) file and its companion index.

    Each sample is a string representing the raw content of an record. The
    index file is expected next to the record file, with the same stem and
    an `.idx` extension.

    Parameters
    ----------
    filename : str
        Path to rec file.
    """
    def __init__(self, filename):
        stem, _ext = os.path.splitext(filename)
        self.idx_file = stem + '.idx'
        self.filename = filename
        self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')

    def __getitem__(self, idx):
        # Translate the positional index into the record key, then read it.
        key = self._record.keys[idx]
        return self._record.read_idx(key)

    def __len__(self):
        return len(self._record.keys)

    def __mx_handle__(self):
        from ._internal import RecordFileDataset as _BackendRecordFileDataset
        return _BackendRecordFileDataset(rec_file=self.filename, idx_file=self.idx_file)
class _DownloadedDataset(Dataset):
    """Base class for MNIST, cifar10, etc.

    Subclasses implement `_get_data`, which `__init__` calls after creating
    `root`; it is expected to populate `self._data` and `self._label`, which
    start out as None and are read by `__getitem__`, `__len__` and
    `__mx_handle__`.
    """
    def __init__(self, root, transform):
        super(_DownloadedDataset, self).__init__()
        # A non-None transform is rejected outright (the warning class is
        # raised, not emitted), so the value stored below is always None.
        if transform is not None:
            raise DeprecationWarning(
                'Directly apply transform to dataset is deprecated. '
                'Please use dataset.transform() or dataset.transform_first() instead...')
        self._transform = transform
        self._data = None
        self._label = None
        root = os.path.expanduser(root)
        self._root = root
        # exist_ok avoids a check-then-create race when several processes set
        # up the same root concurrently; a non-directory already at `root`
        # still raises FileExistsError, as before.
        os.makedirs(root, exist_ok=True)
        self._get_data()
        self.handle = None

    def __getitem__(self, idx):
        if self._transform is not None:
            return self._transform(self._data[idx], self._label[idx])
        return self._data[idx], self._label[idx]

    def __len__(self):
        return len(self._label)

    def _get_data(self):
        raise NotImplementedError

    def __mx_handle__(self):
        if self.handle is None:
            from ._internal import NDArrayDataset, GroupDataset
            self.handle = GroupDataset(
                datasets=(NDArrayDataset(arr=default_array(self._data)),
                          NDArrayDataset(arr=default_array(self._label))))
        return self.handle
# (web-page residue from the documentation site, kept as a comment)
# Did this page help you? Yes / No. Thanks for your feedback!