Source code for mxnet.image.image

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=no-member, too-many-lines, redefined-builtin, protected-access, unused-import, invalid-name
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches, too-many-statements
"""Read individual image files and perform augmentations."""

from __future__ import absolute_import, print_function

import sys
import os
import random
import logging
import json
import warnings
import numpy as np


try:
    import cv2
except ImportError:
    cv2 = None

from ..base import numeric_types
from .. import ndarray as nd
from ..ndarray import _internal
from ..ndarray._internal import _cvimresize as imresize
from ..ndarray._internal import _cvcopyMakeBorder as copyMakeBorder
from .. import io
from .. import recordio


def imread(filename, *args, **kwargs):
    """Read and decode an image file to an NDArray.

    Thin wrapper that forwards all arguments to the backend operator
    ``_internal._cvimread``.

    Note: `imread` uses OpenCV (not the CV2 Python library).
    MXNet must have been built with USE_OPENCV=1 for `imread` to work.

    Parameters
    ----------
    filename : str
        Name of the image file to be loaded.
    flag : {0, 1}, default 1
        1 for three channel color output. 0 for grayscale output.
    to_rgb : bool, default True
        True for RGB formatted output (MXNet default).
        False for BGR formatted output (OpenCV default).
    out : NDArray, optional
        Output buffer. Use `None` for automatic allocation.

    Returns
    -------
    NDArray
        An `NDArray` containing the image.

    Example
    -------
    >>> image = mx.img.imread("flower.jpg")

    Set `flag` parameter to 0 to get grayscale output

    >>> image = mx.img.imread("flower.jpg", flag=0)

    Set `to_rgb` parameter to 0 to get output in OpenCV format (BGR)

    >>> image = mx.img.imread("flower.jpg", to_rgb=0)
    """
    return _internal._cvimread(filename, *args, **kwargs)


def imdecode(buf, *args, **kwargs):
    """Decode an in-memory encoded image to an NDArray.

    Note: `imdecode` uses OpenCV (not the CV2 Python library).
    MXNet must have been built with USE_OPENCV=1 for `imdecode` to work.

    Parameters
    ----------
    buf : str/bytes/bytearray, numpy.ndarray or NDArray
        Binary image data. Anything that is not already an `NDArray` is
        reinterpreted as a flat ``uint8`` buffer before decoding. On Python 3
        a text `str` is rejected; convert it to `bytes` first.
    flag : int, optional, default=1
        1 for three channel color output. 0 for grayscale output.
    to_rgb : int, optional, default=1
        1 for RGB formatted output (MXNet default). 0 for BGR formatted output (OpenCV default).
    out : NDArray, optional
        Output buffer. Use `None` for automatic allocation.

    Returns
    -------
    NDArray
        An `NDArray` containing the image.

    Raises
    ------
    ValueError
        On Python 3, if `buf` is not an NDArray, bytes, bytearray or numpy.ndarray.

    Example
    -------
    >>> with open("flower.jpg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image)

    Set `flag` parameter to 0 to get grayscale output

    >>> image = mx.img.imdecode(str_image, flag=0)

    Set `to_rgb` parameter to 0 to get output in OpenCV format (BGR)

    >>> image = mx.img.imdecode(str_image, to_rgb=0)
    """
    if not isinstance(buf, nd.NDArray):
        if sys.version_info[0] == 3 and not isinstance(buf, (bytes, bytearray, np.ndarray)):
            # The trailing space keeps the two concatenated literals from running together.
            raise ValueError('buf must be of type bytes, bytearray or numpy.ndarray, '
                             'if you would like to input type str, please convert to bytes')
        # View the raw bytes as uint8 and hand them to the backend as an NDArray.
        buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8)

    return _internal._cvimdecode(buf, *args, **kwargs)


def scale_down(src_size, size):
    """Shrink a crop size so that it fits inside the image.

    When the requested crop is taller and/or wider than the image, the
    offending dimension is clamped to the image dimension and the other
    dimension is reduced proportionally (aspect ratio of the crop is kept,
    fractional parts are truncated).

    Parameters
    ----------
    src_size : tuple of int
        Size of the image in (width, height) format.
    size : tuple of int
        Size of the crop in (width, height) format.

    Returns
    -------
    tuple of int
        The adjusted crop size in (width, height) format.

    Example
    --------
    >>> src_size = (640, 480)
    >>> size = (720, 120)
    >>> mx.img.scale_down(src_size, size)
    (640, 106)
    """
    crop_w, crop_h = size
    img_w, img_h = src_size

    # Height first: clamp to the image height, scale the width to match.
    if crop_h > img_h:
        crop_w = float(crop_w * img_h) / crop_h
        crop_h = img_h

    # Then width: clamp to the image width, scale the height to match.
    if crop_w > img_w:
        crop_h = float(crop_h * img_w) / crop_w
        crop_w = img_w

    return int(crop_w), int(crop_h)


def _get_interp_method(interp, sizes=()):
    """Get the interpolation method for resize functions.
    The major purpose of this function is to wrap a random interp method selection
    and a auto-estimation method.

    Parameters
    ----------
    interp : int
        interpolation method for all resizing operations

        Possible values:
        0: Nearest Neighbors Interpolation.
        1: Bilinear interpolation.
        2: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method. (used by default).
        3: Bicubic interpolation over 4x4 pixel neighborhood.
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarge, area for shrink, bilinear for others
        10: Random select from interpolation method metioned above.
        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).
        More details can be found in the documentation of OpenCV, please refer to
        http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.
    sizes : tuple of int
        (old_height, old_width, new_height, new_width), if None provided, auto(9)
        will return Area(2) anyway.

    Returns
    -------
    int
        interp method from 0 to 4
    """
    if interp == 9:
        if sizes:
            assert len(sizes) == 4
            oh, ow, nh, nw = sizes
            if nh > oh and nw > ow:
                return 2
            elif nh < oh and nw < ow:
                return 3
            else:
                return 1
        else:
            return 2
    if interp == 10:
        return random.randint(0, 4)
    if interp not in (0, 1, 2, 3, 4):
        raise ValueError('Unknown interp method %d' % interp)
    return interp


def resize_short(src, size, interp=2):
    """Resize an image so that its shorter edge equals `size`.

    The longer edge is scaled by the same factor (integer floor division),
    so the aspect ratio is approximately preserved.

    Note: `resize_short` uses OpenCV (not the CV2 Python library).
    MXNet must have been built with OpenCV for `resize_short` to work.

    Parameters
    ----------
    src : NDArray
        The original image, indexed as (height, width, channels).
    size : int
        The length to be set for the shorter edge.
    interp : int, optional, default=2
        Interpolation method used for resizing the image.
        Possible values:
        0: Nearest Neighbors Interpolation.
        1: Bilinear interpolation.
        2: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method. (used by default).
        3: Bicubic interpolation over 4x4 pixel neighborhood.
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarge, area for shrink, bilinear for others
        10: Random select from interpolation method mentioned above.
        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).
        More details can be found in the documentation of OpenCV, please refer to
        http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.

    Returns
    -------
    NDArray
        An 'NDArray' containing the resized image.

    Example
    -------
    >>> with open("flower.jpeg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image)
    >>> size = 640
    >>> new_image = mx.img.resize_short(image, size)
    """
    height, width, _ = src.shape
    if height > width:
        # Portrait: width is the short edge.
        target_w = size
        target_h = size * height // width
    else:
        # Landscape or square: height is the short edge.
        target_h = size
        target_w = size * width // height
    method = _get_interp_method(interp, (height, width, target_h, target_w))
    return imresize(src, target_w, target_h, interp=method)


def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
    """Crop src at fixed location, and (optionally) resize it to size.

    Parameters
    ----------
    src : NDArray
        Input image, indexed as (height, width, channels).
    x0 : int
        Left boundary of the cropping area
    y0 : int
        Top boundary of the cropping area
    w : int
        Width of the cropping area
    h : int
        Height of the cropping area
    size : tuple or list of (w, h)
        Optional, resize to new size after cropping. No resize is performed
        when it already equals the crop's (w, h).
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.

    Returns
    -------
    NDArray
        An `NDArray` containing the cropped image.
    """
    out = nd.crop(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2])))
    # Normalise `size` to a tuple before comparing: a list such as [w, h] never
    # compares equal to the tuple (w, h), which used to force a no-op resize.
    if size is not None and (w, h) != tuple(size):
        sizes = (h, w, size[1], size[0])
        out = imresize(out, *size, interp=_get_interp_method(interp, sizes))
    return out


def random_crop(src, size, interp=2):
    """Crop a randomly positioned window of `size` (width, height) out of `src`.

    The crop window is first shrunk with `scale_down` so it fits in the image;
    the cropped region is then resized to `size` by `fixed_crop` if needed, so
    the result is upsampled when `src` is smaller than `size`.

    Parameters
    ----------
    src : NDArray
        Source image, indexed as (height, width, channels).
    size : tuple of (int, int)
        Size of the crop formatted as (width, height).
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.

    Returns
    -------
    NDArray
        An `NDArray` containing the cropped image.
    Tuple
        A tuple (x, y, width, height) where (x, y) is top-left position of the crop in the
        original image and (width, height) are the dimensions of the cropped image.

    Example
    -------
    >>> im = mx.nd.array(cv2.imread("flower.jpg"))
    >>> cropped_im, rect = mx.image.random_crop(im, (100, 100))
    >>> print(rect)
    (20, 21, 100, 100)
    """
    img_h, img_w, _ = src.shape
    crop_w, crop_h = scale_down((img_w, img_h), size)

    # Draw the horizontal offset first, then the vertical one.
    left = random.randint(0, img_w - crop_w)
    top = random.randint(0, img_h - crop_h)

    cropped = fixed_crop(src, left, top, crop_w, crop_h, size, interp)
    return cropped, (left, top, crop_w, crop_h)


def center_crop(src, size, interp=2):
    """Crop the central region of `src` and resize it to `size`.

    The crop window is shrunk with `scale_down` to fit in the image, centred,
    and the cropped region is resized to `size` by `fixed_crop` if needed, so
    the result is upsampled when `src` is smaller than `size`.

    .. note:: This requires MXNet to be compiled with USE_OPENCV.

    Parameters
    ----------
    src : NDArray
        Source image, indexed as (height, width, channels).
    size : list or tuple of int
        The desired output image size as (width, height).
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.

    Returns
    -------
    NDArray
        The cropped image.
    Tuple
        (x, y, width, height) where x, y are the positions of the crop in the
        original image and width, height the dimensions of the crop.

    Example
    -------
    >>> with open("flower.jpg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.image.imdecode(str_image)
    >>> cropped_image, (x, y, width, height) = mx.image.center_crop(image, (1000, 500))
    """
    img_h, img_w, _ = src.shape
    crop_w, crop_h = scale_down((img_w, img_h), size)

    # Split the leftover margin evenly; any odd pixel goes to the right/bottom.
    left = int((img_w - crop_w) / 2)
    top = int((img_h - crop_h) / 2)

    cropped = fixed_crop(src, left, top, crop_w, crop_h, size, interp)
    return cropped, (left, top, crop_w, crop_h)


def color_normalize(src, mean, std=None):
    """Subtract `mean` from `src` and divide by `std`, in place.

    Both steps are optional: passing `None` skips the corresponding step.
    The input array itself is modified and returned.

    Parameters
    ----------
    src : NDArray
        Input image
    mean : NDArray
        RGB mean to be subtracted
    std : NDArray
        RGB standard deviation to be divided

    Returns
    -------
    NDArray
        `src`, after in-place normalization.
    """
    if mean is not None:
        src -= mean
    if std is None:
        return src
    src /= std
    return src


def random_size_crop(src, size, area, ratio, interp=2, **kwargs):
    """Randomly crop src with size. Randomize area and aspect ratio.

    Up to ten candidate windows are sampled; the first one that fits inside
    the image is cropped at a random position and resized to `size`. If none
    fits, the result of `center_crop` is returned instead.

    Parameters
    ----------
    src : NDArray
        Input image, indexed as (height, width, channels).
    size : tuple of (int, int)
        Size of the crop formatted as (width, height).
    area : float in (0, 1] or tuple of (float, float)
        If tuple, minimum area and maximum area to be maintained after cropping
        If float, minimum area to be maintained after cropping, maximum area is set to 1.0
    ratio : tuple of (float, float)
        Aspect ratio range as (min_aspect_ratio, max_aspect_ratio)
    interp: int, optional, default=2
        Interpolation method. See resize_short for details.
    **kwargs
        Only the deprecated `min_area` is accepted; it overrides `area`.

    Returns
    -------
    NDArray
        An `NDArray` containing the cropped image.
    Tuple
        A tuple (x, y, width, height) where (x, y) is top-left position of the crop in the
        original image and (width, height) are the dimensions of the cropped image.

    """
    img_h, img_w, _ = src.shape
    total_area = img_h * img_w

    # Backward compatibility: `min_area` was the old name of `area`.
    if 'min_area' in kwargs:
        warnings.warn('`min_area` is deprecated. Please use `area` instead.',
                      DeprecationWarning)
        area = kwargs.pop('min_area')
    assert not kwargs, "unexpected keyword arguments for `random_size_crop`."

    # A bare number is a lower bound; the upper bound is then the full image.
    if isinstance(area, numeric_types):
        area = (area, 1.0)

    attempts = 10
    for _attempt in range(attempts):
        target = random.uniform(area[0], area[1]) * total_area
        aspect = random.uniform(*ratio)

        crop_w = int(round(np.sqrt(target * aspect)))
        crop_h = int(round(np.sqrt(target / aspect)))

        # Flip orientation half of the time.
        if random.random() < 0.5:
            crop_w, crop_h = crop_h, crop_w

        if crop_w > img_w or crop_h > img_h:
            # Candidate does not fit in the image; try again.
            continue

        left = random.randint(0, img_w - crop_w)
        top = random.randint(0, img_h - crop_h)
        cropped = fixed_crop(src, left, top, crop_w, crop_h, size, interp)
        return cropped, (left, top, crop_w, crop_h)

    # fall back to center_crop
    return center_crop(src, size, interp)


class Augmenter(object):
    """Image Augmenter base class.

    Keeps the constructor keyword arguments (with array values converted to
    plain lists) so the augmenter can be described as JSON via `dumps`.
    Subclasses implement `__call__`.
    """

    def __init__(self, **kwargs):
        self._kwargs = kwargs
        # Snapshot the keys so entries can be reassigned while looping.
        for key in list(kwargs):
            value = kwargs[key]
            if isinstance(value, nd.NDArray):
                value = value.asnumpy()
            if isinstance(value, np.ndarray):
                # Store arrays as nested lists so json.dumps can handle them.
                kwargs[key] = value.tolist()

    def dumps(self):
        """Saves the Augmenter to string

        Returns
        -------
        str
            JSON formatted string that describes the Augmenter.
        """
        description = [self.__class__.__name__.lower(), self._kwargs]
        return json.dumps(description)

    def __call__(self, src):
        """Abstract implementation body"""
        raise NotImplementedError("Must override implementation.")
class SequentialAug(Augmenter):
    """Composing a sequential augmenter list.

    Parameters
    ----------
    ts : list of augmenters
        A series of augmenters to be applied in sequential order.
    """

    def __init__(self, ts):
        super(SequentialAug, self).__init__()
        self.ts = ts

    def dumps(self):
        """Override the default to avoid duplicate dump.

        Returns a list ``[name, [child dumps...]]`` rather than a JSON string.
        """
        children = [child.dumps() for child in self.ts]
        return [self.__class__.__name__.lower(), children]

    def __call__(self, src):
        """Augmenter body: feed `src` through every augmenter in order."""
        result = src
        for augmenter in self.ts:
            result = augmenter(result)
        return result
class ResizeAug(Augmenter):
    """Augmenter that resizes the shorter edge of an image to `size`.

    Parameters
    ----------
    size : int
        The length to be set for the shorter edge.
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.
    """

    def __init__(self, size, interp=2):
        super(ResizeAug, self).__init__(size=size, interp=interp)
        self.interp = interp
        self.size = size

    def __call__(self, src):
        """Augmenter body"""
        resized = resize_short(src, self.size, self.interp)
        return resized
class ForceResizeAug(Augmenter):
    """Force resize to size regardless of aspect ratio

    Parameters
    ----------
    size : tuple of (int, int)
        The desired size as in (width, height)
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.
    """

    def __init__(self, size, interp=2):
        super(ForceResizeAug, self).__init__(size=size, interp=interp)
        self.interp = interp
        self.size = size

    def __call__(self, src):
        """Augmenter body"""
        # (old_height, old_width, new_height, new_width) for interp auto-selection.
        old_new = (src.shape[0], src.shape[1], self.size[1], self.size[0])
        method = _get_interp_method(self.interp, old_new)
        return imresize(src, *self.size, interp=method)
class RandomCropAug(Augmenter):
    """Make random crop augmenter

    Parameters
    ----------
    size : tuple of (int, int)
        Size of the crop formatted as (width, height); passed to `random_crop`.
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.
    """

    def __init__(self, size, interp=2):
        super(RandomCropAug, self).__init__(size=size, interp=interp)
        self.interp = interp
        self.size = size

    def __call__(self, src):
        """Augmenter body"""
        # random_crop also returns the crop rectangle; only the image is kept.
        cropped, _rect = random_crop(src, self.size, self.interp)
        return cropped
class RandomSizedCropAug(Augmenter):
    """Make random crop with random resizing and random aspect ratio jitter augmenter.

    Parameters
    ----------
    size : tuple of (int, int)
        Size of the crop formatted as (width, height).
    area : float in (0, 1] or tuple of (float, float)
        If tuple, minimum area and maximum area to be maintained after cropping
        If float, minimum area to be maintained after cropping, maximum area is set to 1.0
    ratio : tuple of (float, float)
        Aspect ratio range as (min_aspect_ratio, max_aspect_ratio)
    interp: int, optional, default=2
        Interpolation method. See resize_short for details.
    **kwargs
        Only the deprecated `min_area` is accepted; it overrides `area`.
    """

    def __init__(self, size, area, ratio, interp=2, **kwargs):
        # Resolve the deprecated alias first, so the configuration registered
        # for `dumps()` matches the area actually used when cropping.
        if 'min_area' in kwargs:
            warnings.warn('`min_area` is deprecated. Please use `area` instead.',
                          DeprecationWarning)
            area = kwargs.pop('min_area')
        assert not kwargs, "unexpected keyword arguments for `RandomSizedCropAug`."

        super(RandomSizedCropAug, self).__init__(size=size, area=area,
                                                 ratio=ratio, interp=interp)
        self.size = size
        self.area = area
        self.ratio = ratio
        self.interp = interp

    def __call__(self, src):
        """Augmenter body"""
        # random_size_crop also returns the crop rectangle; only the image is kept.
        return random_size_crop(src, self.size, self.area, self.ratio, self.interp)[0]
class CenterCropAug(Augmenter):
    """Make center crop augmenter.

    Parameters
    ----------
    size : list or tuple of int
        The desired output image size.
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.
    """

    def __init__(self, size, interp=2):
        super(CenterCropAug, self).__init__(size=size, interp=interp)
        self.interp = interp
        self.size = size

    def __call__(self, src):
        """Augmenter body"""
        # center_crop also returns the crop rectangle; only the image is kept.
        cropped, _rect = center_crop(src, self.size, self.interp)
        return cropped
class RandomOrderAug(Augmenter):
    """Apply list of augmenters in random order

    Parameters
    ----------
    ts : list of augmenters
        A series of augmenters to be applied in random order
    """

    def __init__(self, ts):
        super(RandomOrderAug, self).__init__()
        self.ts = ts

    def dumps(self):
        """Override the default to avoid duplicate dump.

        Returns a list ``[name, [child dumps...]]`` rather than a JSON string.
        """
        children = [child.dumps() for child in self.ts]
        return [self.__class__.__name__.lower(), children]

    def __call__(self, src):
        """Augmenter body: shuffle the stored list in place, then apply in that order."""
        random.shuffle(self.ts)
        result = src
        for augmenter in self.ts:
            result = augmenter(result)
        return result
class BrightnessJitterAug(Augmenter):
    """Random brightness jitter augmentation.

    Parameters
    ----------
    brightness : float
        The brightness jitter ratio range, [0, 1]
    """

    def __init__(self, brightness):
        super(BrightnessJitterAug, self).__init__(brightness=brightness)
        self.brightness = brightness

    def __call__(self, src):
        """Augmenter body: scale `src` in place by a random factor around 1."""
        jitter = random.uniform(-self.brightness, self.brightness)
        src *= 1.0 + jitter
        return src
class ContrastJitterAug(Augmenter):
    """Random contrast jitter augmentation.

    Parameters
    ----------
    contrast : float
        The contrast jitter ratio range, [0, 1]
    """

    def __init__(self, contrast):
        super(ContrastJitterAug, self).__init__(contrast=contrast)
        self.contrast = contrast
        # Per-channel weights used to form the weighted (gray) image.
        self.coef = nd.array([[[0.299, 0.587, 0.114]]])

    def __call__(self, src):
        """Augmenter body: blend `src` (in place) with a global gray offset."""
        alpha = 1.0 + random.uniform(-self.contrast, self.contrast)
        weighted = src * self.coef
        # Offset derived from the sum of the weighted image, scaled by (1 - alpha).
        scale = 3.0 * (1.0 - alpha) / weighted.size
        offset = scale * nd.sum(weighted)
        src *= alpha
        src += offset
        return src
class SaturationJitterAug(Augmenter):
    """Random saturation jitter augmentation.

    Parameters
    ----------
    saturation : float
        The saturation jitter ratio range, [0, 1]
    """

    def __init__(self, saturation):
        super(SaturationJitterAug, self).__init__(saturation=saturation)
        self.saturation = saturation
        # Per-channel weights used to form the per-pixel gray value.
        self.coef = nd.array([[[0.299, 0.587, 0.114]]])

    def __call__(self, src):
        """Augmenter body: blend `src` (in place) with its per-pixel gray value."""
        alpha = 1.0 + random.uniform(-self.saturation, self.saturation)
        weighted = src * self.coef
        gray = nd.sum(weighted, axis=2, keepdims=True)
        gray *= (1.0 - alpha)
        src *= alpha
        src += gray
        return src
class HueJitterAug(Augmenter):
    """Random hue jitter augmentation.

    Parameters
    ----------
    hue : float
        The hue jitter ratio range, [0, 1]
    """

    def __init__(self, hue):
        super(HueJitterAug, self).__init__(hue=hue)
        self.hue = hue
        # Fixed 3x3 colour-space matrices used on either side of the rotation.
        self.tyiq = np.array([[0.299, 0.587, 0.114],
                              [0.596, -0.274, -0.321],
                              [0.211, -0.523, 0.311]])
        self.ityiq = np.array([[1.0, 0.956, 0.621],
                               [1.0, -0.272, -0.647],
                               [1.0, -1.107, 1.705]])

    def __call__(self, src):
        """Augmenter body.
        Using approximate linear transformation described in:
        https://beesbuzz.biz/code/hsv_color_transforms.php
        """
        alpha = random.uniform(-self.hue, self.hue)
        u = np.cos(alpha * np.pi)
        w = np.sin(alpha * np.pi)
        # Rotation acting on the last two components only.
        rotation = np.array([[1.0, 0.0, 0.0],
                             [0.0, u, -w],
                             [0.0, w, u]])
        rotated = np.dot(self.ityiq, rotation)
        transform = np.dot(rotated, self.tyiq).T
        return nd.dot(src, nd.array(transform))
class ColorJitterAug(RandomOrderAug):
    """Apply random brightness, contrast and saturation jitter in random order.

    Only jitters with a strictly positive range are included.

    Parameters
    ----------
    brightness : float
        The brightness jitter ratio range, [0, 1]
    contrast : float
        The contrast jitter ratio range, [0, 1]
    saturation : float
        The saturation jitter ratio range, [0, 1]
    """

    def __init__(self, brightness, contrast, saturation):
        candidates = (
            (brightness, BrightnessJitterAug),
            (contrast, ContrastJitterAug),
            (saturation, SaturationJitterAug),
        )
        ts = []
        for amount, aug_cls in candidates:
            if amount > 0:
                ts.append(aug_cls(amount))
        super(ColorJitterAug, self).__init__(ts)
class LightingAug(Augmenter):
    """Add PCA based noise.

    Parameters
    ----------
    alphastd : float
        Noise level
    eigval : 3x1 np.array
        Eigen values
    eigvec : 3x3 np.array
        Eigen vectors
    """

    def __init__(self, alphastd, eigval, eigvec):
        super(LightingAug, self).__init__(alphastd=alphastd, eigval=eigval, eigvec=eigvec)
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, src):
        """Augmenter body: add a random offset to `src` in place."""
        # One normally distributed coefficient per eigenvector column.
        alpha = np.random.normal(0, self.alphastd, size=(3,))
        scaled_vectors = self.eigvec * alpha
        rgb = np.dot(scaled_vectors, self.eigval)
        src += nd.array(rgb)
        return src
class ColorNormalizeAug(Augmenter):
    """Mean and std normalization.

    Parameters
    ----------
    mean : NDArray
        RGB mean to be subtracted
    std : NDArray
        RGB standard deviation to be divided
    """

    def __init__(self, mean, std):
        super(ColorNormalizeAug, self).__init__(mean=mean, std=std)
        # Anything that is neither None nor already an NDArray is converted once here.
        if mean is not None and not isinstance(mean, nd.NDArray):
            mean = nd.array(mean)
        if std is not None and not isinstance(std, nd.NDArray):
            std = nd.array(std)
        self.mean = mean
        self.std = std

    def __call__(self, src):
        """Augmenter body"""
        normalized = color_normalize(src, self.mean, self.std)
        return normalized
class RandomGrayAug(Augmenter):
    """Randomly convert to gray image.

    Parameters
    ----------
    p : float
        Probability to convert to grayscale
    """

    def __init__(self, p):
        super(RandomGrayAug, self).__init__(p=p)
        self.p = p
        # 3x3 mixing matrix with identical columns, so every output channel
        # receives the same weighted sum of the input channels.
        self.mat = nd.array([[0.21, 0.21, 0.21],
                             [0.72, 0.72, 0.72],
                             [0.07, 0.07, 0.07]])

    def __call__(self, src):
        """Augmenter body"""
        if random.random() < self.p:
            return nd.dot(src, self.mat)
        return src
class HorizontalFlipAug(Augmenter):
    """Random horizontal flip.

    Parameters
    ----------
    p : float
        Probability to flip image horizontally
    """

    def __init__(self, p):
        super(HorizontalFlipAug, self).__init__(p=p)
        self.p = p

    def __call__(self, src):
        """Augmenter body"""
        if random.random() < self.p:
            # Reverse along axis 1.
            return nd.flip(src, axis=1)
        return src
class CastAug(Augmenter):
    """Cast the image to a given data type (float32 by default).

    Parameters
    ----------
    typ : str, optional, default='float32'
        Target type passed to ``astype``.
    """

    def __init__(self, typ='float32'):
        super(CastAug, self).__init__(type=typ)
        self.typ = typ

    def __call__(self, src):
        """Augmenter body"""
        return src.astype(self.typ)
def CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, rand_mirror=False,
                    mean=None, std=None, brightness=0, contrast=0, saturation=0, hue=0,
                    pca_noise=0, rand_gray=0, inter_method=2):
    """Creates an augmenter list.

    The augmenters are appended in a fixed order: optional resize, exactly one
    crop, optional mirror, cast, optional color jitter, optional hue jitter,
    optional PCA noise, optional random gray and optional mean/std
    normalization.

    Parameters
    ----------
    data_shape : tuple of int
        Shape for output data. The crop size is taken as
        ``(data_shape[2], data_shape[1])``.
    resize : int
        Resize shorter edge if larger than 0 at the beginning
    rand_crop : bool
        Whether to enable random cropping other than center crop
    rand_resize : bool
        Whether to enable random sized cropping, require rand_crop to be enabled
    rand_gray : float
        [0, 1], probability to convert to grayscale for all channels, the number
        of channels will not be reduced to 1
    rand_mirror : bool
        Whether to apply horizontal flip to image with probability 0.5
    mean : np.ndarray, NDArray, True or None
        Mean pixel values for [r, g, b]. ``True`` selects the built-in default
        ``[123.68, 116.28, 103.53]``.
    std : np.ndarray, NDArray, True or None
        Standard deviations for [r, g, b]. ``True`` selects the built-in default
        ``[58.395, 57.12, 57.375]``.
    brightness : float
        Brightness jittering range (percent)
    contrast : float
        Contrast jittering range (percent)
    saturation : float
        Saturation jittering range (percent)
    hue : float
        Hue jittering range (percent)
    pca_noise : float
        Pca noise level (percent)
    inter_method : int, default=2(Area-based)
        Interpolation method for all resizing operations

        Possible values:
        0: Nearest Neighbors Interpolation.
        1: Bilinear interpolation.
        2: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method. (used by default).
        3: Bicubic interpolation over 4x4 pixel neighborhood.
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarge, area for shrink, bilinear for others
        10: Random select from interpolation method mentioned above.
        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).

    Returns
    -------
    list
        The augmenters, in the order they were appended.

    Examples
    --------
    >>> # An example of creating multiple augmenters
    >>> augs = mx.image.CreateAugmenter(data_shape=(3, 300, 300), rand_mirror=True,
    ...    mean=True, brightness=0.125, contrast=0.125, rand_gray=0.05,
    ...    saturation=0.125, pca_noise=0.05, inter_method=10)
    >>> # dump the details
    >>> for aug in augs:
    ...    aug.dumps()
    """

    def _resolve_stat(value, default):
        """Turn ``True`` into the default NDArray; validate explicit arrays."""
        if value is True:
            return nd.array(default)
        if value is not None:
            assert isinstance(value, (np.ndarray, nd.NDArray)) and value.shape[0] in [1, 3]
        return value

    augmenters = []

    if resize > 0:
        augmenters.append(ResizeAug(resize, inter_method))

    # Exactly one crop augmenter is always added.
    target_size = (data_shape[2], data_shape[1])
    if rand_resize:
        assert rand_crop
        cropper = RandomSizedCropAug(target_size, 0.08, (3.0 / 4.0, 4.0 / 3.0), inter_method)
    elif rand_crop:
        cropper = RandomCropAug(target_size, inter_method)
    else:
        cropper = CenterCropAug(target_size, inter_method)
    augmenters.append(cropper)

    if rand_mirror:
        augmenters.append(HorizontalFlipAug(0.5))

    # The cast is unconditional and precedes every color augmenter.
    augmenters.append(CastAug())

    if brightness or contrast or saturation:
        augmenters.append(ColorJitterAug(brightness, contrast, saturation))

    if hue:
        augmenters.append(HueJitterAug(hue))

    if pca_noise > 0:
        pca_eigval = np.array([55.46, 4.794, 1.148])
        pca_eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                               [-0.5808, -0.0045, -0.8140],
                               [-0.5836, -0.6948, 0.4203]])
        augmenters.append(LightingAug(pca_noise, pca_eigval, pca_eigvec))

    if rand_gray > 0:
        augmenters.append(RandomGrayAug(rand_gray))

    mean = _resolve_stat(mean, [123.68, 116.28, 103.53])
    std = _resolve_stat(std, [58.395, 57.12, 57.375])
    if not (mean is None and std is None):
        augmenters.append(ColorNormalizeAug(mean, std))

    return augmenters
class ImageIter(io.DataIter):
    """Image data iterator with a large number of augmentation choices.
    This iterator supports reading from both .rec files and raw image files.

    To load input images from .rec files, use `path_imgrec` parameter and to load from raw image
    files, use `path_imglist` and `path_root` parameters.

    To use data partition (for distributed training) or shuffling, specify `path_imgidx` parameter.

    Parameters
    ----------
    batch_size : int
        Number of examples per batch.
    data_shape : tuple
        Data shape in (channels, height, width) format.
        For now, only RGB image with 3 channels is supported.
    label_width : int, optional
        Number of labels per example. The default label width is 1.
    path_imgrec : str
        Path to image record file (.rec).
        Created with tools/im2rec.py or bin/im2rec.
    path_imglist : str
        Path to image list (.lst).
        Created with tools/im2rec.py or with custom script.
        Format: Tab separated record of index, one or more labels and relative_path_from_root.
    imglist: list
        A list of images with the label(s).
        Each item is a list [imagelabel: float or list of float, imgpath].
    path_root : str
        Root folder of image files. When it is None, image paths are used as given.
    path_imgidx : str
        Path to image index file. Needed for partition and shuffling when using .rec source.
    shuffle : bool
        Whether to shuffle all images at the start of each iteration or not.
        Can be slow for HDD.
    part_index : int
        Partition index.
    num_parts : int
        Total number of partitions.
    aug_list : list or None
        Augmenters applied, in order, to every decoded image. When None, the list is
        built with ``CreateAugmenter(data_shape, **kwargs)``.
    data_name : str
        Data name for provided symbols.
    label_name : str
        Label name for provided symbols.
    dtype : str
        Label data type. Default: float32. Other options: int32, int64, float64
    last_batch_handle : str, optional
        How to handle the last batch.
        This parameter can be 'pad'(default), 'discard' or 'roll_over'.
        If 'pad', the last batch will be padded with data starting from the beginning
        If 'discard', the last batch will be discarded
        If 'roll_over', the remaining elements will be rolled over to the next iteration
    kwargs : ...
        More arguments for creating augmenter. See mx.image.CreateAugmenter.
    """

    def __init__(self, batch_size, data_shape, label_width=1,
                 path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
                 shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 last_batch_handle='pad', **kwargs):
        super(ImageIter, self).__init__()
        assert path_imgrec or path_imglist or (isinstance(imglist, list))
        assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported'
        num_threads = os.environ.get('MXNET_CPU_WORKER_NTHREADS', 1)
        logging.info('Using %s threads for decoding...', str(num_threads))
        logging.info('Set enviroment variable MXNET_CPU_WORKER_NTHREADS to a'
                     ' larger number to use more threads.')
        class_name = self.__class__.__name__

        # --- record file source (optional) ---
        if path_imgrec:
            logging.info('%s: loading recordio %s...', class_name, path_imgrec)
            if path_imgidx:
                self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
                self.imgidx = list(self.imgrec.keys)
            else:
                self.imgrec = recordio.MXRecordIO(path_imgrec, 'r')  # pylint: disable=redefined-variable-type
                self.imgidx = None
        else:
            self.imgrec = None

        # --- image list source (optional): maps key -> (label, path) ---
        if path_imglist:
            logging.info('%s: loading image list %s...', class_name, path_imglist)
            with open(path_imglist) as fin:
                imglist = {}
                imgkeys = []
                for line in fin:
                    # Fields: index, one or more labels, path (tab separated).
                    fields = line.strip().split('\t')
                    label = nd.array(fields[1:-1], dtype=dtype)
                    key = int(fields[0])
                    imglist[key] = (label, fields[-1])
                    imgkeys.append(key)
                self.imglist = imglist
        elif isinstance(imglist, list):
            logging.info('%s: loading image list...', class_name)
            result = {}
            imgkeys = []
            # In-memory lists are keyed by their 1-based position, as a string.
            for index, img in enumerate(imglist, 1):
                key = str(index)  # pylint: disable=redefined-variable-type
                if len(img) > 2:
                    label = nd.array(img[:-1], dtype=dtype)
                elif isinstance(img[0], numeric_types):
                    label = nd.array([img[0]], dtype=dtype)
                else:
                    label = nd.array(img[0], dtype=dtype)
                result[key] = (label, img[-1])
                imgkeys.append(key)
            self.imglist = result
        else:
            self.imglist = None
        self.path_root = path_root

        self.check_data_shape(data_shape)
        self.provide_data = [(data_name, (batch_size,) + data_shape)]
        if label_width > 1:
            self.provide_label = [(label_name, (batch_size, label_width))]
        else:
            self.provide_label = [(label_name, (batch_size,))]
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.label_width = label_width

        # --- iteration order ---
        # `seq` is the list of keys to visit; None means the record file is
        # read sequentially without an index.
        self.shuffle = shuffle
        if self.imgrec is None:
            self.seq = imgkeys
        elif shuffle or num_parts > 1 or path_imgidx:
            assert self.imgidx is not None
            self.seq = self.imgidx
        else:
            self.seq = None

        if num_parts > 1:
            assert part_index < num_parts
            N = len(self.seq)
            C = N // num_parts
            self.seq = self.seq[part_index * C:(part_index + 1) * C]
        if aug_list is None:
            self.auglist = CreateAugmenter(data_shape, **kwargs)
        else:
            self.auglist = aug_list
        self.cur = 0
        self._allow_read = True
        self.last_batch_handle = last_batch_handle
        self.num_image = len(self.seq) if self.seq is not None else None
        # Partially filled batch kept between epochs for 'roll_over'.
        self._cache_data = None
        self._cache_label = None
        self._cache_idx = None
        self.reset()

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        if self.seq is not None and self.shuffle:
            random.shuffle(self.seq)
        # With 'roll_over' and a cached partial batch the read position is
        # left where it is; otherwise reading restarts from the beginning.
        if self.last_batch_handle != 'roll_over' or \
            self._cache_data is None:
            if self.imgrec is not None:
                self.imgrec.reset()
            self.cur = 0
            if self._allow_read is False:
                self._allow_read = True

    def hard_reset(self):
        """Resets the iterator and ignore roll over data"""
        if self.seq is not None and self.shuffle:
            random.shuffle(self.seq)
        if self.imgrec is not None:
            self.imgrec.reset()
        self.cur = 0
        self._allow_read = True
        self._cache_data = None
        self._cache_label = None
        self._cache_idx = None

    def next_sample(self):
        """Helper function for reading in next sample.

        Returns
        -------
        tuple
            ``(label, raw_image)`` where ``raw_image`` is the still-encoded image.

        Raises
        ------
        StopIteration
            When reading is disallowed or the data source is exhausted.
        """
        if self._allow_read is False:
            raise StopIteration
        if self.seq is not None:
            if self.cur < self.num_image:
                idx = self.seq[self.cur]
            else:
                # End of the sequence: rewind unless the tail is discarded.
                if self.last_batch_handle != 'discard':
                    self.cur = 0
                raise StopIteration
            self.cur += 1
            if self.imgrec is not None:
                s = self.imgrec.read_idx(idx)
                header, img = recordio.unpack(s)
                if self.imglist is None:
                    return header.label, img
                else:
                    return self.imglist[idx][0], img
            else:
                label, fname = self.imglist[idx]
                return label, self.read_image(fname)
        else:
            # Sequential read from a record file without an index.
            s = self.imgrec.read()
            if s is None:
                if self.last_batch_handle != 'discard':
                    self.imgrec.reset()
                raise StopIteration
            header, img = recordio.unpack(s)
            return header.label, img

    def _batchify(self, batch_data, batch_label, start=0):
        """Helper function for batchifying data

        Fills ``batch_data`` and ``batch_label`` from position ``start`` onwards
        and returns the index one past the last slot that was filled.
        StopIteration is re-raised only when nothing at all has been filled.
        """
        i = start
        batch_size = self.batch_size
        try:
            while i < batch_size:
                label, s = self.next_sample()
                data = self.imdecode(s)
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping: %s', str(e))
                    continue
                data = self.augmentation_transform(data)
                assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                batch_data[i] = self.postprocess_data(data)
                batch_label[i] = label
                i += 1
        except StopIteration:
            if not i:
                raise StopIteration
        return i

    def next(self):
        """Returns the next batch of data.

        Raises
        ------
        StopIteration
            When no data is left, or when the last batch is incomplete and
            `last_batch_handle` is 'discard' or (first time round) 'roll_over'.
        """
        batch_size = self.batch_size
        c, h, w = self.data_shape
        # if last batch data is rolled over
        if self._cache_data is not None:
            # check both the data and label have values
            assert self._cache_label is not None, "_cache_label didn't have values"
            assert self._cache_idx is not None, "_cache_idx didn't have values"
            batch_data = self._cache_data
            batch_label = self._cache_label
            i = self._cache_idx
            # clear the cache data
        else:
            batch_data = nd.zeros((batch_size, c, h, w))
            batch_label = nd.empty(self.provide_label[0][1])
            i = self._batchify(batch_data, batch_label)
        # calculate the padding
        pad = batch_size - i
        # handle padding for the last batch
        if pad != 0:
            if self.last_batch_handle == 'discard':
                raise StopIteration
            # if the option is 'roll_over', throw StopIteration and cache the data
            elif self.last_batch_handle == 'roll_over' and \
                self._cache_data is None:
                self._cache_data = batch_data
                self._cache_label = batch_label
                self._cache_idx = i
                raise StopIteration
            else:
                # Fill the remaining slots with further samples.
                _ = self._batchify(batch_data, batch_label, i)
                if self.last_batch_handle == 'pad':
                    self._allow_read = False
                else:
                    self._cache_data = None
                    self._cache_label = None
                    self._cache_idx = None
        return io.DataBatch([batch_data], [batch_label], pad=pad)

    def check_data_shape(self, data_shape):
        """Checks if the input data shape is valid

        Raises
        ------
        ValueError
            If `data_shape` does not have length 3 or its first entry is not 3.
        """
        if not len(data_shape) == 3:
            raise ValueError('data_shape should have length 3, with dimensions CxHxW')
        if not data_shape[0] == 3:
            raise ValueError('This iterator expects inputs to have 3 channels.')

    def check_valid_image(self, data):
        """Checks if the input data is valid

        Raises
        ------
        RuntimeError
            If ``data[0]`` has an empty shape.
        """
        if len(data[0].shape) == 0:
            raise RuntimeError('Data shape is wrong')

    def imdecode(self, s):
        """Decodes a string or byte string to an NDArray.
        See mx.img.imdecode for more details.

        Raises
        ------
        RuntimeError
            If decoding fails; the message names the offending file or index
            when it can be determined.
        """
        def locate():
            """Locate the image file/index if decode fails."""
            if self.seq is None:
                # Sequential record reads keep no key sequence and `num_image`
                # is None, so no position can be reported. Returning here keeps
                # the error handler from raising and hiding the decode failure.
                return "Broken image index: unknown"
            idx = self.seq[(self.cur % self.num_image) - 1]
            if self.imglist is not None:
                _, fname = self.imglist[idx]
                msg = "filename: {}".format(fname)
            else:
                msg = "index: {}".format(idx)
            return "Broken image " + msg

        try:
            img = imdecode(s)
        except Exception as e:
            raise RuntimeError("{}, {}".format(locate(), e))
        return img

    def read_image(self, fname):
        """Reads an input image `fname` and returns its raw file contents.

        The path is resolved relative to `path_root` when one was given and is
        used as-is otherwise. The bytes are returned as read from disk.

        Examples
        --------
        >>> dataIter.read_image('Face.jpg') # returns the raw bytes of the file.
        """
        if self.path_root is None:
            # os.path.join(None, ...) raises TypeError, so fall back to fname.
            path = fname
        else:
            path = os.path.join(self.path_root, fname)
        with open(path, 'rb') as fin:
            img = fin.read()
        return img

    def augmentation_transform(self, data):
        """Transforms input data with specified augmentation."""
        # Augmenters are chained: each one receives the previous one's output.
        for aug in self.auglist:
            data = aug(data)
        return data

    def postprocess_data(self, datum):
        """Final postprocessing step before image is loaded into the batch."""
        # Move the last axis to the front (axes 2, 0, 1).
        return nd.transpose(datum, axes=(2, 0, 1))