gluon.data¶
Dataset utilities.
Datasets¶
Dataset | Abstract dataset class.
ArrayDataset | A dataset that combines multiple dataset-like objects, e.g. Datasets, lists, arrays, etc.
RecordFileDataset | A dataset wrapping over a RecordIO (.rec) file.
SimpleDataset | Simple Dataset wrapper for lists and arrays.
Sampling¶
Sampler | Base class for samplers.
SequentialSampler | Samples elements from [start, start+length) sequentially.
RandomSampler | Samples elements from [0, length) randomly without replacement.
BatchSampler | Wraps over another Sampler and returns mini-batches of samples.
DataLoader¶
DataLoader | Loads data from a dataset and returns mini-batches of data.
API Reference¶
Dataset utilities.
Classes
ArrayDataset | A dataset that combines multiple dataset-like objects, e.g. Datasets, lists, arrays, etc.
BatchSampler | Wraps over another Sampler and returns mini-batches of samples.
DataLoader | Loads data from a dataset and returns mini-batches of data.
Dataset | Abstract dataset class.
FilterSampler | Samples elements from a Dataset for which fn returns True.
RandomSampler | Samples elements from [0, length) randomly without replacement.
RecordFileDataset | A dataset wrapping over a RecordIO (.rec) file.
Sampler | Base class for samplers.
SequentialSampler | Samples elements from [start, start+length) sequentially.
SimpleDataset | Simple Dataset wrapper for lists and arrays.
-
class mxnet.gluon.data.ArrayDataset(*args)[source]¶
Bases: mxnet.gluon.data.dataset.Dataset
A dataset that combines multiple dataset-like objects, e.g. Datasets, lists, arrays, etc.
The i-th sample is defined as (x1[i], x2[i], …).
- Parameters
*args (one or more dataset-like objects) – The data arrays.
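For illustration, a minimal sketch of pairing features with labels (not part of the upstream docstring):
>>> import mxnet as mx
>>> from mxnet import gluon
>>> features = mx.nd.random.uniform(shape=(10, 3))
>>> labels = [i % 2 for i in range(10)]  # plain lists work as dataset-like objects too
>>> dataset = gluon.data.ArrayDataset(features, labels)
>>> dataset[0][0].shape  # the i-th sample is (features[i], labels[i])
(3,)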
-
class mxnet.gluon.data.BatchSampler(sampler, batch_size, last_batch='keep')[source]¶
Bases: mxnet.gluon.data.sampler.Sampler
Wraps over another Sampler and returns mini-batches of samples.
- Parameters
sampler (Sampler) – The source Sampler.
batch_size (int) – Size of mini-batch.
last_batch ({'keep', 'discard', 'rollover'}) –
Specifies how the last batch is handled if batch_size does not evenly divide sequence length.
If ‘keep’, the last batch will be returned directly, but will contain fewer elements than batch_size requires.
If ‘discard’, the last batch will be discarded.
If ‘rollover’, the remaining elements will be rolled over to the next iteration.
Examples
>>> sampler = gluon.data.SequentialSampler(10)
>>> batch_sampler = gluon.data.BatchSampler(sampler, 3, 'keep')
>>> list(batch_sampler)
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
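For comparison, an illustrative sketch of the 'discard' policy with the same sampler (not from the upstream docstring); with 'rollover', the trailing [9] would instead be carried over into the next pass:
>>> batch_sampler = gluon.data.BatchSampler(sampler, 3, 'discard')
>>> list(batch_sampler)
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]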
-
class mxnet.gluon.data.DataLoader(dataset, batch_size=None, shuffle=False, sampler=None, last_batch=None, batch_sampler=None, batchify_fn=None, num_workers=0, pin_memory=False, pin_device_id=0, prefetch=None, thread_pool=False, timeout=120, auto_reload=False)[source]¶
Bases: object
Loads data from a dataset and returns mini-batches of data.
- Parameters
dataset (Dataset) – Source dataset. Note that numpy and mxnet arrays can be directly used as a Dataset.
batch_size (int) – Size of mini-batch.
shuffle (bool) – Whether to shuffle the samples.
sampler (Sampler) – The sampler to use. Either specify sampler or shuffle, not both.
last_batch ({'keep', 'discard', 'rollover'}) –
How to handle the last batch if batch_size does not evenly divide len(dataset).
keep - A batch with fewer samples than previous batches is returned.
discard - The last batch is discarded if it is incomplete.
rollover - The remaining samples are rolled over to the next epoch.
batch_sampler (Sampler) – A sampler that returns mini-batches. Do not specify batch_size, shuffle, sampler, and last_batch if batch_sampler is specified.
batchify_fn (callable) –
Callback function to allow users to specify how to merge samples into a batch. Defaults to default_batchify_fn:
def default_batchify_fn(data):
    if isinstance(data[0], nd.NDArray):
        return nd.stack(*data)
    elif isinstance(data[0], tuple):
        data = zip(*data)
        return [default_batchify_fn(i) for i in data]
    else:
        data = np.asarray(data)
        return nd.array(data, dtype=data.dtype)
num_workers (int, default 0) – The number of multiprocessing workers to use for data preprocessing.
pin_memory (boolean, default False) – If True, the dataloader will copy NDArrays into pinned memory before returning them. Copying from CPU pinned memory to GPU is faster than copying from normal CPU memory.
pin_device_id (int, default 0) – The device id to use for allocating pinned memory if pin_memory is True.
prefetch (int, default is num_workers * 2) – The number of batches to prefetch; only effective if num_workers > 0. If prefetch > 0, worker processes prefetch that many batches before data is requested from the iterator. A larger prefetch count gives smoother throughput but consumes more shared memory; too small a count may forfeit the purpose of using multiple worker processes (try reducing num_workers in that case).
thread_pool (bool, default False) – If True, use a thread pool instead of a multiprocessing pool. Using a thread pool avoids shared-memory usage. If the DataLoader is more I/O bound, or the GIL is not a bottleneck, the thread-pool version may achieve better performance than multiprocessing.
timeout (int, default is 120) – The timeout in seconds for each worker to fetch a batch of data. Only modify this if you are experiencing a timeout that you know is due to slow data loading. Sometimes a full shared memory will cause all workers to hang, causing a timeout; in these cases, reduce num_workers or increase the system shared-memory size instead.
auto_reload (bool, default False) – Controls whether data for the next batch is prefetched after a batch ends.
Example
>>> from mxnet.gluon.data import DataLoader, ArrayDataset
>>> train_data = ArrayDataset([i for i in range(10)], [9-i for i in range(10)])
>>> def transform_train(sample):
...     if sample == 0:
...         print('(pre)fetching data here')
...     return sample
...
>>> train_iter = DataLoader(train_data.transform_first(transform_train),
...                         auto_reload=False, batch_size=1, num_workers=1)
>>> # no prefetch is performed; the prefetch & autoload start after
>>> # train_iter.__iter__() is called.
>>> for i in train_iter: pass
(pre)fetching data here
>>> train_iter = DataLoader(train_data.transform_first(transform_train),
...                         auto_reload=True, batch_size=1, num_workers=1)
(pre)fetching data here
>>> it = iter(train_iter)  # nothing is generated since lazy-evaluation occurs
>>> it2 = iter(train_iter)
>>> it3 = iter(train_iter)
>>> it4 = iter(train_iter)
>>> _ = next(it2)  # the first iter we use is the prefetched iter.
>>> _ = next(it)  # since the prefetched iter is consumed, we have to fetch data for it.
(pre)fetching data here
>>> _ = [None for _ in it3]
(pre)fetching data here
(pre)fetching data here
>>> # Here, 2 prefetches are triggered: one fetches the first batch of it3, and
>>> # another is performed automatically when it3 yields its last item.
>>> _ = [None for _ in it]
>>> # no prefetch happens since train_iter has already prefetched data.
>>> _ = next(it4)
>>> # since the prefetch was performed, it4 becomes the prefetched iter.
>>> test_data = ArrayDataset([i for i in range(10)], [9-i for i in range(10)])
>>> test_iter = DataLoader(test_data, batch_size=1, num_workers=1)
>>> for epoch in range(200):
...     # there is almost no difference between this and the default DataLoader
...     for data, label in train_iter:
...         pass  # training...
...     for data, label in test_iter:
...         pass  # testing...
Methods
clean() | Remove the prefetched iter; the prefetch step will start again after __iter__() is called.
refresh() | Refresh the iter, fetching data again from the dataset.
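As a minimal usage sketch (illustrative only; shapes and parameter values here are arbitrary assumptions):
>>> import mxnet as mx
>>> from mxnet import gluon
>>> X = mx.nd.random.uniform(shape=(10, 3))
>>> y = mx.nd.arange(10)
>>> loader = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y),
...                                batch_size=4, shuffle=True, last_batch='discard')
>>> for data, label in loader:
...     print(data.shape, label.shape)
(4, 3) (4,)
(4, 3) (4,)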
-
class mxnet.gluon.data.Dataset[source]¶
Bases: object
Abstract dataset class. All datasets should have this interface.
Subclasses need to override __getitem__, which returns the i-th element, and __len__, which returns the total number of elements.
Note
An mxnet or numpy array can be directly used as a dataset.
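A minimal sketch of the interface (the subclass here is hypothetical, for illustration only):
>>> from mxnet import gluon
>>> class SquaresDataset(gluon.data.Dataset):
...     def __getitem__(self, idx):
...         return idx * idx  # the idx-th element
...     def __len__(self):
...         return 5  # the total number of elements
...
>>> squares = SquaresDataset()
>>> squares[3], len(squares)
(9, 5)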
Methods
filter(fn) | Returns a new dataset with samples filtered by the filter function fn.
sample(sampler) | Returns a new dataset with elements sampled by the sampler.
shard(num_shards, index) | Returns a new dataset which includes only 1/num_shards of this dataset.
take(count) | Returns a new dataset with at most count number of samples in it.
transform(fn[, lazy]) | Returns a new dataset with each sample transformed by the transformer function fn.
transform_first(fn[, lazy]) | Returns a new dataset with the first element of each sample transformed by the transformer function fn.
-
filter(fn)[source]¶
Returns a new dataset with samples filtered by the filter function fn.
Note that if the Dataset is the result of a lazy transform (transform(lazy=True)), the filter is eagerly applied to the transformed samples without materializing the transformed result. That is, the transformation will be applied again whenever a sample is retrieved after filter().
- Parameters
fn (callable) – A filter function that takes a sample as input and returns a boolean. Samples that return False are discarded.
- Returns
The filtered dataset.
- Return type
Dataset
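An illustrative sketch (not part of the upstream docstring):
>>> dataset = gluon.data.SimpleDataset(list(range(10)))
>>> even = dataset.filter(lambda x: x % 2 == 0)
>>> len(even)
5
>>> even[1]
2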
-
shard(num_shards, index)[source]¶
Returns a new dataset which includes only 1/num_shards of this dataset.
For distributed training, be sure to shard before you randomize the dataset (such as shuffle), if you want each worker to get a unique subset.
- Parameters
num_shards (int) – An integer representing the number of data shards.
index (int) – An integer representing the index of the current shard.
- Returns
The result dataset.
- Return type
Dataset
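An illustrative sketch, splitting 12 samples into 4 equal shards (not part of the upstream docstring):
>>> dataset = gluon.data.SimpleDataset(list(range(12)))
>>> shard0 = dataset.shard(4, 0)
>>> len(shard0)
3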
-
take(count)[source]¶
Returns a new dataset with at most count number of samples in it.
- Parameters
count (int or None) – An integer representing the number of elements of this dataset that should be taken to form the new dataset. If count is None, or if count is greater than the size of this dataset, the new dataset will contain all elements of this dataset.
- Returns
The result dataset.
- Return type
Dataset
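An illustrative sketch (not part of the upstream docstring):
>>> dataset = gluon.data.SimpleDataset(list(range(10)))
>>> first3 = dataset.take(3)
>>> [first3[i] for i in range(len(first3))]
[0, 1, 2]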
-
transform(fn, lazy=True)[source]¶
Returns a new dataset with each sample transformed by the transformer function fn.
- Parameters
fn (callable) – A transformer function that takes a sample as input and returns the transformed sample.
lazy (bool, default True) – If False, transforms all samples at once. Otherwise, transforms each sample on demand. Note that if fn is stochastic, you must set lazy to True or you will get the same result on all epochs.
- Returns
The transformed dataset.
- Return type
Dataset
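An illustrative sketch (not part of the upstream docstring):
>>> dataset = gluon.data.SimpleDataset([1, 2, 3])
>>> doubled = dataset.transform(lambda x: x * 2)
>>> doubled[0]
2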
-
transform_first(fn, lazy=True)[source]¶
Returns a new dataset with the first element of each sample transformed by the transformer function fn.
This is useful, for example, when you only want to transform data while keeping label as is.
- Parameters
fn (callable) – A transformer function that takes the first element of a sample as input and returns the transformed element.
lazy (bool, default True) – If False, transforms all samples at once. Otherwise, transforms each sample on demand. Note that if fn is stochastic, you must set lazy to True or you will get the same result on all epochs.
- Returns
The transformed dataset.
- Return type
Dataset
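An illustrative sketch, transforming data while keeping labels as-is (not part of the upstream docstring):
>>> dataset = gluon.data.ArrayDataset([1, 2, 3], ['a', 'b', 'c'])
>>> scaled = dataset.transform_first(lambda x: x * 10)
>>> scaled[1]
(20, 'b')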
-
class mxnet.gluon.data.FilterSampler(fn, dataset)[source]¶
Bases: mxnet.gluon.data.sampler.Sampler
Samples elements from a Dataset for which fn returns True.
- Parameters
fn (callable) – A callable function that takes a sample and returns a boolean.
dataset (Dataset) – The dataset to filter.
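An illustrative sketch; the sampler yields the indices of samples for which fn returns True (not part of the upstream docstring):
>>> dataset = gluon.data.SimpleDataset(list(range(10)))
>>> sampler = gluon.data.FilterSampler(lambda x: x > 6, dataset)
>>> list(sampler)
[7, 8, 9]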
-
class mxnet.gluon.data.RandomSampler(length)[source]¶
Bases: mxnet.gluon.data.sampler.Sampler
Samples elements from [0, length) randomly without replacement.
- Parameters
length (int) – Length of the sequence.
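An illustrative sketch (not part of the upstream docstring):
>>> sampler = gluon.data.RandomSampler(4)
>>> sorted(sampler)  # order is random; each index appears exactly once
[0, 1, 2, 3]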
-
class mxnet.gluon.data.RecordFileDataset(filename)[source]¶
Bases: mxnet.gluon.data.dataset.Dataset
A dataset wrapping over a RecordIO (.rec) file.
Each sample is a string representing the raw content of a record.
- Parameters
filename (str) – Path to rec file.
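An illustrative sketch; the file path is hypothetical (a .rec file is typically produced with the im2rec tool):
>>> dataset = gluon.data.RecordFileDataset('data.rec')
>>> record = dataset[0]  # raw content of the first record
>>> n = len(dataset)  # number of records in the file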
-
class mxnet.gluon.data.Sampler[source]¶
Bases: object
Base class for samplers.
All samplers should subclass Sampler and define __iter__ and __len__ methods.
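A minimal sketch of a custom sampler (the subclass is hypothetical, for illustration only):
>>> from mxnet import gluon
>>> class EvenIndexSampler(gluon.data.Sampler):
...     def __init__(self, length):
...         self._length = length
...     def __iter__(self):
...         return iter(range(0, self._length, 2))
...     def __len__(self):
...         return (self._length + 1) // 2
...
>>> list(EvenIndexSampler(5))
[0, 2, 4]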
-
class mxnet.gluon.data.SequentialSampler(length, start=0)[source]¶
Bases: mxnet.gluon.data.sampler.Sampler
Samples elements from [start, start+length) sequentially.
- Parameters
length (int) – Length of the sequence.
start (int, default is 0) – The index at which the sequence starts.
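An illustrative sketch (not part of the upstream docstring):
>>> list(gluon.data.SequentialSampler(4, start=10))
[10, 11, 12, 13]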
-
class mxnet.gluon.data.SimpleDataset(data)[source]¶
Bases: mxnet.gluon.data.dataset.Dataset
Simple Dataset wrapper for lists and arrays.
- Parameters
data (dataset-like object) – Any object that implements len() and [].
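An illustrative sketch (not part of the upstream docstring):
>>> dataset = gluon.data.SimpleDataset([('a', 0), ('b', 1)])
>>> len(dataset)
2
>>> dataset[1]
('b', 1)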