# Source code for mxnet.gluon.data.dataloader

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=
"""Dataset generator."""
__all__ = ['DataLoader']

import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
import pickle
import io
import os
import sys
import warnings
import numpy as np

from . import sampler as _sampler
from ... import nd, context


def rebuild_ndarray(*args):
    """Reconstruct an NDArray on the receiving side of a pickle transfer.

    This is the callable that `reduce_ndarray` hands to the pickler; it is
    invoked with the arguments that `reduce_ndarray` produced.

    Parameters
    ----------
    *args
        Forwarded unchanged to `nd.ndarray._new_from_shared_mem`.

    Returns
    -------
    NDArray
        An `nd.NDArray` wrapping whatever `_new_from_shared_mem` returned.
    """
    # pylint: disable=no-value-for-parameter
    shared = nd.ndarray._new_from_shared_mem(*args)
    return nd.NDArray(shared)


def reduce_ndarray(data):
    """Pickle reducer that describes an NDArray by its shared-memory arguments.

    Parameters
    ----------
    data : NDArray
        The array being pickled.

    Returns
    -------
    tuple
        A ``(callable, args)`` pair: `rebuild_ndarray` and the result of
        ``data._to_shared_mem()``, which the unpickler passes back to
        `rebuild_ndarray`.
    """
    shared_mem_args = data._to_shared_mem()
    return rebuild_ndarray, shared_mem_args

# Module-level side effect: make ForkingPickler reduce every nd.NDArray with
# reduce_ndarray, so arrays are described by their shared-memory arguments
# when pickled through ForkingPickler.
ForkingPickler.register(nd.NDArray, reduce_ndarray)


class ConnectionWrapper(object):
    """Connection wrapper for multiprocessing that supports sending
    NDArray via shared memory.

    Objects are serialized with `ForkingPickler` in `send`, so any reducer
    registered on `ForkingPickler` is applied. Every attribute that is not
    defined on the wrapper itself is looked up on the wrapped connection.

    Parameters
    ----------
    conn : object
        The wrapped connection. It must provide ``send_bytes`` and
        ``recv_bytes``, which `send` and `recv` call.
    """

    def __init__(self, conn):
        self.conn = conn

    def send(self, obj):
        """Pickle `obj` with ForkingPickler and send the resulting bytes."""
        buf = io.BytesIO()
        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
        self.send_bytes(buf.getvalue())

    def recv(self):
        """Receive one bytes message and return the unpickled object."""
        buf = self.recv_bytes()
        return pickle.loads(buf)

    def __getattr__(self, name):
        """Emulate conn by forwarding unknown attributes to it.

        Raises
        ------
        AttributeError
            If the wrapper has no ``conn`` yet, or ``conn`` lacks `name`.
        """
        # Read `conn` straight from the instance dict. Going through
        # `self.conn` would re-enter __getattr__ whenever `conn` is not set
        # (e.g. on an instance created with __new__ while pickle/copy probe
        # for __setstate__), recursing until RecursionError instead of
        # reporting a plain AttributeError.
        if 'conn' not in self.__dict__:
            raise AttributeError(name)
        return getattr(self.__dict__['conn'], name)


class Queue(multiprocessing.queues.Queue):
    """multiprocessing queue whose reader and writer ends are wrapped in
    `ConnectionWrapper`, so NDArray is dumped with shared memory."""
    def __init__(self, *args, **kwargs):
        if sys.version_info[0] >= 3:
            # On Python 3 the base constructor is additionally given the
            # current multiprocessing context as `ctx`.
            super(Queue, self).__init__(
                *args, ctx=multiprocessing.get_context(), **kwargs)
        else:
            super(Queue, self).__init__(*args, **kwargs)

        # Replace both pipe ends with wrappers and point the send/recv
        # hooks at the wrappers' pickling-aware methods.
        reader = ConnectionWrapper(self._reader)
        writer = ConnectionWrapper(self._writer)
        self._reader = reader
        self._writer = writer
        self._send = writer.send
        self._recv = reader.recv


def default_batchify_fn(data):
    """Merge a list of samples into one batch.

    The type of the first sample selects the strategy:

    - NDArray samples are stacked with `nd.stack`.
    - tuple samples are transposed with `zip` and each field is batched
      recursively; the result is a list with one batch per field.
    - anything else is converted with `np.asarray` and turned into an
      `nd.array` of the same dtype.

    Parameters
    ----------
    data : sequence
        The samples of one mini-batch. Must be non-empty, since the first
        element is inspected.
    """
    first = data[0]
    if isinstance(first, nd.NDArray):
        return nd.stack(*data)
    if isinstance(first, tuple):
        return [default_batchify_fn(field) for field in zip(*data)]
    as_numpy = np.asarray(data)
    return nd.array(as_numpy, dtype=as_numpy.dtype)


def default_mp_batchify_fn(data):
    """Merge a list of samples into one batch placed in the 'cpu_shared' context.

    Dispatches on the type of the first sample exactly like
    `default_batchify_fn`, but every array it creates uses
    ``context.Context('cpu_shared', 0)``:

    - NDArray samples are stacked into a pre-allocated output of shape
      ``(len(data),) + data[0].shape``.
    - tuple samples are transposed with `zip` and each field is batched
      recursively; the result is a list with one batch per field.
    - anything else is converted with `np.asarray` and turned into an
      `nd.array` of the same dtype.

    Parameters
    ----------
    data : sequence
        The samples of one mini-batch. Must be non-empty, since the first
        element is inspected.
    """
    first = data[0]
    if isinstance(first, nd.NDArray):
        batch_shape = (len(data),) + first.shape
        stacked = nd.empty(batch_shape, dtype=first.dtype,
                           ctx=context.Context('cpu_shared', 0))
        return nd.stack(*data, out=stacked)
    if isinstance(first, tuple):
        return [default_mp_batchify_fn(field) for field in zip(*data)]
    as_numpy = np.asarray(data)
    return nd.array(as_numpy, dtype=as_numpy.dtype,
                    ctx=context.Context('cpu_shared', 0))


def worker_loop(dataset, key_queue, data_queue, batchify_fn):
    """Serve batch requests until a stop message arrives.

    Parameters
    ----------
    dataset
        Indexable source of samples; accessed as ``dataset[i]``.
    key_queue
        Queue of ``(idx, samples)`` requests, read with ``get()``. A
        request whose ``idx`` is None ends the loop.
    data_queue
        Queue that receives ``(idx, batch)`` results via ``put()``.
    batchify_fn : callable
        Called with the list of fetched samples to build each batch.
    """
    batch_id, sample_ids = key_queue.get()
    while batch_id is not None:
        fetched = []
        for sample_id in sample_ids:
            fetched.append(dataset[sample_id])
        data_queue.put((batch_id, batchify_fn(fetched)))
        batch_id, sample_ids = key_queue.get()


class DataLoader(object):
    """Loads data from a dataset and returns mini-batches of data.

    Parameters
    ----------
    dataset : Dataset
        Source dataset. Note that numpy and mxnet arrays can be directly used
        as a Dataset.
    batch_size : int
        Size of mini-batch.
    shuffle : bool
        Whether to shuffle the samples.
    sampler : Sampler
        The sampler to use. Either specify sampler or shuffle, not both.
    last_batch : {'keep', 'discard', 'rollover'}
        How to handle the last batch if batch_size does not evenly divide
        `len(dataset)`.

        keep - A batch with less samples than previous batches is returned.
        discard - The last batch is discarded if its incomplete.
        rollover - The remaining samples are rolled over to the next epoch.
    batch_sampler : Sampler
        A sampler that returns mini-batches. Do not specify batch_size,
        shuffle, sampler, and last_batch if batch_sampler is specified.
    batchify_fn : callable
        Callback function to allow users to specify how to merge samples
        into a batch. Defaults to `default_batchify_fn`::

            def default_batchify_fn(data):
                if isinstance(data[0], nd.NDArray):
                    return nd.stack(*data)
                elif isinstance(data[0], tuple):
                    data = zip(*data)
                    return [default_batchify_fn(i) for i in data]
                else:
                    data = np.asarray(data)
                    return nd.array(data, dtype=data.dtype)

    num_workers : int, default 0
        The number of multiprocessing workers to use for data preprocessing.
        `num_workers > 0` is not supported on Windows yet.

    Raises
    ------
    ValueError
        If neither `batch_size` nor `batch_sampler` is given, if `shuffle`
        is combined with `sampler`, or if `batch_sampler` is combined with
        any of `batch_size`, `shuffle`, `sampler` or `last_batch`.
    """
    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 last_batch=None, batch_sampler=None, batchify_fn=None,
                 num_workers=0):
        self._dataset = dataset

        if batch_sampler is None:
            if batch_size is None:
                raise ValueError("batch_size must be specified unless " \
                                 "batch_sampler is specified")
            if sampler is None:
                if shuffle:
                    sampler = _sampler.RandomSampler(len(dataset))
                else:
                    sampler = _sampler.SequentialSampler(len(dataset))
            elif shuffle:
                raise ValueError("shuffle must not be specified if sampler is specified")

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')
        elif batch_size is not None or shuffle or sampler is not None or \
                last_batch is not None:
            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
                             "not be specified if batch_sampler is specified.")

        self._batch_sampler = batch_sampler
        if num_workers > 0 and os.name == 'nt':
            # Fall back to in-process loading instead of failing on Windows.
            warnings.warn("DataLoader does not support num_workers > 0 on Windows yet.")
            num_workers = 0
        self._num_workers = num_workers
        if batchify_fn is None:
            # The multi-worker default builds batches in the 'cpu_shared'
            # context; the in-process default does not.
            if num_workers > 0:
                self._batchify_fn = default_mp_batchify_fn
            else:
                self._batchify_fn = default_batchify_fn
        else:
            self._batchify_fn = batchify_fn

    def __iter__(self):
        """Yield the batches of one pass over the batch sampler, in order."""
        if self._num_workers == 0:
            for batch in self._batch_sampler:
                yield self._batchify_fn([self._dataset[idx] for idx in batch])
            return

        key_queue = Queue()
        # Bounded, so workers block once 2 * num_workers results are waiting.
        data_queue = Queue(2*self._num_workers)

        workers = []
        for _ in range(self._num_workers):
            worker = multiprocessing.Process(
                target=worker_loop,
                args=(self._dataset, key_queue, data_queue, self._batchify_fn))
            worker.daemon = True
            worker.start()
            workers.append(worker)

        # Count explicitly rather than reading the loop variable afterwards:
        # with an empty batch sampler the loop variable is never bound, and
        # the workers would never be sent their stop message.
        num_batches = 0
        for idx, batch in enumerate(self._batch_sampler):
            key_queue.put((idx, batch))
            num_batches += 1

        # Results may arrive out of order; buffer them and release each
        # batch only when its index is the next one expected.
        data_buffer = {}
        curr_idx = 0
        for _ in range(num_batches):
            idx, batch = data_queue.get()
            data_buffer[idx] = batch
            while curr_idx in data_buffer:
                yield data_buffer.pop(curr_idx)
                curr_idx += 1

        # One stop message per worker, sent only after every batch has been
        # consumed; then wait for the workers to exit.
        for _ in range(self._num_workers):
            key_queue.put((None, None))

        for worker in workers:
            worker.join()

    def __len__(self):
        """Return the number of batches, i.e. ``len`` of the batch sampler."""
        return len(self._batch_sampler)