# Source code for mxnet.gluon.data.dataset
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=
"""Dataset container.

Defines the abstract `Dataset` interface together with concrete datasets
wrapping in-memory sequences (`SimpleDataset`, `ArrayDataset`) and RecordIO
files (`RecordFileDataset`).
"""
# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = ['Dataset', 'SimpleDataset', 'ArrayDataset',
           'RecordFileDataset']
import os
from ... import recordio, ndarray
class Dataset(object):
    """Abstract dataset class. All datasets should have this interface.

    Subclasses need to override `__getitem__`, which returns the i-th
    element, and `__len__`, which returns the total number of elements.

    .. note:: An mxnet or numpy array can be directly used as a dataset.
    """
    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def filter(self, fn):
        """Returns a new dataset with samples filtered by the
        filter function `fn`.

        Note that if the Dataset is the result of a lazily transformed one with
        transform(lazy=True), the filter is eagerly applied to the transformed
        samples without materializing the transformed result. That is, the
        transformation will be applied again whenever a sample is retrieved after
        filter().

        Parameters
        ----------
        fn : callable
            A filter function that takes a sample as input and
            returns a boolean. Samples that return False are discarded.

        Returns
        -------
        Dataset
            The filtered dataset.
        """
        from . import FilterSampler
        return _SampledDataset(self, FilterSampler(fn, self))

    def shard(self, num_shards, index):
        """Returns a new dataset that includes only 1/num_shards of this dataset.

        For distributed training, be sure to shard before you randomize the dataset
        (such as shuffle), if you want each worker to reach a unique subset.

        Parameters
        ----------
        num_shards : int
            An integer representing the number of data shards.
        index : int
            An integer representing the index of the current shard.

        Returns
        -------
        Dataset
            The result dataset.
        """
        # Validate num_shards first so that a non-positive shard count is
        # reported as such rather than as an out-of-bound index.
        assert num_shards > 0, 'Number of shards must be greater than 0'
        assert index >= 0, 'Index must be non-negative'
        assert index < num_shards, \
            'Shard index out of bound: %d out of %d' % (index, num_shards)
        length = len(self)
        shard_len = length // num_shards
        rest = length % num_shards
        # The first `rest` shards each receive one extra sample, so shard
        # sizes differ by at most one and every sample lands in exactly one
        # shard.
        start = shard_len * index + min(index, rest)
        end = start + shard_len + (index < rest)
        from . import SequentialSampler
        return _SampledDataset(self, SequentialSampler(end - start, start))

    def take(self, count):
        """Returns a new dataset with at most `count` number of samples in it.

        Parameters
        ----------
        count : int or None
            An integer representing the number of elements of this dataset that
            should be taken to form the new dataset. If count is None, or if count
            is greater than the size of this dataset, the new dataset will contain
            all elements of this dataset.

        Returns
        -------
        Dataset
            The result dataset.
        """
        if count is None or count > len(self):
            count = len(self)
        from . import SequentialSampler
        return _SampledDataset(self, SequentialSampler(count))

    def sample(self, sampler):
        """Returns a new dataset with elements sampled by the sampler.

        Parameters
        ----------
        sampler : Sampler
            A Sampler that returns the indices of sampled elements.

        Returns
        -------
        Dataset
            The result dataset.

        Raises
        ------
        TypeError
            If `sampler` is not a `gluon.data.Sampler` instance.
        """
        from . import Sampler
        if not isinstance(sampler, Sampler):
            raise TypeError('Invalid sampler type: %s. Expected gluon.data.Sampler instead.'%
                            type(sampler))
        return _SampledDataset(self, sampler)

    def transform(self, fn, lazy=True):
        """Returns a new dataset with each sample transformed by the
        transformer function `fn`.

        Parameters
        ----------
        fn : callable
            A transformer function that takes a sample as input and
            returns the transformed sample.
        lazy : bool, default True
            If False, transforms all samples at once. Otherwise,
            transforms each sample on demand. Note that if `fn`
            is stochastic, you must set lazy to True or you will
            get the same result on all epochs.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        trans = _LazyTransformDataset(self, fn)
        if lazy:
            return trans
        # Index explicitly over range(len(...)) rather than iterating: the
        # legacy __getitem__ iteration protocol stops at the first IndexError,
        # so an IndexError raised by `fn` would silently truncate the result.
        return SimpleDataset([trans[i] for i in range(len(trans))])

    def transform_first(self, fn, lazy=True):
        """Returns a new dataset with the first element of each sample
        transformed by the transformer function `fn`.

        This is useful, for example, when you only want to transform data
        while keeping label as is.

        Parameters
        ----------
        fn : callable
            A transformer function that takes the first element of a sample
            as input and returns the transformed element.
        lazy : bool, default True
            If False, transforms all samples at once. Otherwise,
            transforms each sample on demand. Note that if `fn`
            is stochastic, you must set lazy to True or you will
            get the same result on all epochs.

        Returns
        -------
        Dataset
            The transformed dataset.
        """
        return self.transform(_TransformFirstClosure(fn), lazy)
class SimpleDataset(Dataset):
    """Simple Dataset wrapper for lists and arrays.

    Lookups and length queries are delegated directly to the wrapped object.

    Parameters
    ----------
    data : dataset-like object
        Any object that implements `len()` and `[]`.
    """
    def __init__(self, data):
        # The wrapped object is stored by reference; no copy is made.
        self._data = data

    def __getitem__(self, idx):
        wrapped = self._data
        return wrapped[idx]

    def __len__(self):
        wrapped = self._data
        return len(wrapped)
class _LazyTransformDataset(Dataset):
    """Dataset that applies `fn` to each sample at access time.

    Tuple samples are unpacked into positional arguments of `fn`; any other
    sample is passed to `fn` as a single argument.
    """
    def __init__(self, data, fn):
        self._data, self._fn = data, fn

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        sample = self._data[idx]
        if not isinstance(sample, tuple):
            return self._fn(sample)
        return self._fn(*sample)
class _TransformFirstClosure(object):
    """Callable that applies `fn` to its first positional argument only.

    A callable object is used instead of a nested function so that instances
    can be pickled. With extra positional arguments the result is a tuple
    ``(fn(x),) + args``; with a single argument it is just ``fn(x)``.
    """
    def __init__(self, fn):
        self._fn = fn

    def __call__(self, x, *args):
        head = self._fn(x)
        return (head,) + args if args else head
class _FilteredDataset(Dataset):
    """Dataset with a filter applied.

    `fn` is evaluated once per sample at construction time. The positions of
    the samples for which it returns a truthy value are stored, and later
    lookups map into the wrapped dataset through those positions.

    Parameters
    ----------
    dataset : dataset-like object
        Any object that implements `len()` and `[]`.
    fn : callable
        Predicate taking a sample and returning a boolean.
    """
    def __init__(self, dataset, fn):
        self._dataset = dataset
        # Index explicitly over range(len(...)) rather than enumerate(): the
        # legacy __getitem__ iteration protocol stops at the first IndexError,
        # so an IndexError raised by `fn` would silently drop the remaining
        # samples instead of propagating.
        self._indices = [i for i in range(len(dataset)) if fn(dataset[i])]

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]
class _SampledDataset(Dataset):
    """Dataset with elements chosen by a sampler.

    The sampler is iterated once at construction time and the resulting
    indices are stored; position `idx` maps to ``dataset[indices[idx]]``.

    Parameters
    ----------
    dataset : dataset-like object
        Any object that implements `[]`.
    sampler : iterable
        Yields the indices of `dataset` to expose.
    """
    def __init__(self, dataset, sampler):
        self._dataset = dataset
        # Kept as an attribute for backward compatibility.
        self._sampler = sampler
        self._indices = list(sampler)

    def __len__(self):
        # Report the number of materialized indices so that len() always
        # agrees with the positions __getitem__ can serve, even if the
        # sampler's own __len__ differs from what it yields or is missing.
        return len(self._indices)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]
class ArrayDataset(Dataset):
    """A dataset that combines multiple dataset-like objects, e.g.
    Datasets, lists, arrays, etc.

    The i-th sample is defined as `(x1[i], x2[i], ...)`. When only one
    object is given, the i-th sample is `x1[i]` itself (not a 1-tuple).

    Parameters
    ----------
    *args : one or more dataset-like objects
        The data arrays.
    """
    def __init__(self, *args):
        assert len(args) > 0, "Needs at least 1 arrays"
        self._length = len(args[0])
        self._data = []
        for i, data in enumerate(args):
            # `i` already counts from 0, matching the `array[0]` reference
            # named in the message, so it is reported unshifted.
            assert len(data) == self._length, \
                "All arrays must have the same length; array[0] has length %d " \
                "while array[%d] has %d." % (self._length, i, len(data))
            # 1-D NDArrays are converted to numpy; presumably so per-sample
            # lookups return numpy scalars rather than NDArrays -- confirm.
            if isinstance(data, ndarray.NDArray) and len(data.shape) == 1:
                data = data.asnumpy()
            self._data.append(data)

    def __getitem__(self, idx):
        if len(self._data) == 1:
            return self._data[0][idx]
        return tuple(data[idx] for data in self._data)

    def __len__(self):
        return self._length
class RecordFileDataset(Dataset):
    """A dataset wrapping over a RecordIO (.rec) file.

    Each sample is a string representing the raw content of a record. The
    companion index file is expected next to the record file, with the same
    base name and an ``.idx`` extension.

    Parameters
    ----------
    filename : str
        Path to rec file.
    """
    def __init__(self, filename):
        stem, _ext = os.path.splitext(filename)
        self.idx_file = stem + '.idx'
        self.filename = filename
        self._record = recordio.MXIndexedRecordIO(self.idx_file, self.filename, 'r')

    def __getitem__(self, idx):
        # Position `idx` is translated to the record key stored at that
        # position in the index before reading.
        key = self._record.keys[idx]
        return self._record.read_idx(key)

    def __len__(self):
        return len(self._record.keys)
class _DownloadedDataset(Dataset):
    """Base class for MNIST, cifar10, etc.

    The constructor expands `root`, creates the directory if it does not
    exist, and then calls `_get_data`, which subclasses must override.
    Samples are read from ``self._data`` / ``self._label``; presumably
    `_get_data` populates both -- confirm in each subclass.

    Parameters
    ----------
    root : str
        Directory for the dataset files; `~` is expanded.
    transform : callable or None
        If not None, called as ``transform(data, label)`` on each access and
        its return value is used as the sample.
    """
    def __init__(self, root, transform):
        super(_DownloadedDataset, self).__init__()
        self._transform = transform
        self._data = None
        self._label = None
        expanded = os.path.expanduser(root)
        self._root = expanded
        if not os.path.isdir(expanded):
            os.makedirs(expanded)
        self._get_data()

    def __getitem__(self, idx):
        transform = self._transform
        if transform is None:
            return self._data[idx], self._label[idx]
        return transform(self._data[idx], self._label[idx])

    def __len__(self):
        return len(self._label)

    def _get_data(self):
        raise NotImplementedError