Source code for mxnet.image.image

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=no-member, too-many-lines, redefined-builtin, protected-access, unused-import, invalid-name
# pylint: disable=too-many-arguments, too-many-locals, no-name-in-module, too-many-branches, too-many-statements
"""Read individual image files and perform augmentations."""

from __future__ import absolute_import, print_function

import sys
import os
import random
import logging
import json
import warnings
import numpy as np
from .. import numpy as _mx_np  # pylint: disable=reimported


try:
    import cv2
except ImportError:
    cv2 = None

from ..base import numeric_types
from .. import ndarray as nd
from ..ndarray import _internal
from .. import io
from .. import recordio
from .. util import is_np_array
from ..ndarray.numpy import _internal as _npi


def imread(filename, *args, **kwargs):
    """Read and decode an image file into an NDArray.

    .. note:: `imread` uses OpenCV (not the CV2 Python library).
       MXNet must have been built with USE_OPENCV=1 for `imread` to work.

    Parameters
    ----------
    filename : str
        Name of the image file to be loaded.
    flag : {0, 1}, default 1
        1 for three channel color output. 0 for grayscale output.
    to_rgb : bool, default True
        True for RGB formatted output (MXNet default).
        False for BGR formatted output (OpenCV default).
    out : NDArray, optional
        Output buffer. Use `None` for automatic allocation.

    Returns
    -------
    NDArray
        An `NDArray` containing the image. When the numpy front-end is
        active (``is_np_array()``), the numpy-interface operator is used.

    Example
    -------
    >>> mx.img.imread("flower.jpg")
    <NDArray 224x224x3 @cpu(0)>

    Set `flag` parameter to 0 to get grayscale output

    >>> mx.img.imread("flower.jpg", flag=0)
    <NDArray 224x224x1 @cpu(0)>

    Set `to_rgb` parameter to 0 to get output in OpenCV format (BGR)

    >>> mx.img.imread("flower.jpg", to_rgb=0)
    <NDArray 224x224x3 @cpu(0)>
    """
    # Dispatch to the operator matching the currently active array front-end.
    reader = _npi.cvimread if is_np_array() else _internal._cvimread
    return reader(filename, *args, **kwargs)
def imresize(src, w, h, *args, **kwargs):
    r"""Resize image with OpenCV.

    .. note:: `imresize` uses OpenCV (not the CV2 Python library). MXNet must have been
       built with USE_OPENCV=1 for `imresize` to work.

    Parameters
    ----------
    src : NDArray
        source image
    w : int, required
        Width of resized image.
    h : int, required
        Height of resized image.
    interp : int, optional, default=1
        Interpolation method (default=cv2.INTER_LINEAR).
        Possible values:
        0: Nearest Neighbors Interpolation.
        1: Bilinear interpolation.
        2: Bicubic interpolation over 4x4 pixel neighborhood.
        3: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method.
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarge, area for shrink, bilinear for others
        10: Random select from interpolation method mentioned above.
        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).
        More details can be found in the documentation of OpenCV, please refer to
        http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.
    out : NDArray, optional
        The output NDArray to hold the result.

    Returns
    -------
    out : NDArray or list of NDArrays
        The output of this function. The result is laid out as
        (height, width, channels), i.e. ``h x w x c``.

    Example
    -------
    >>> with open("flower.jpeg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image)
    >>> image
    <NDArray 2321x3482x3 @cpu(0)>
    >>> new_image = mx.img.imresize(image, 240, 360)
    >>> new_image
    <NDArray 360x240x3 @cpu(0)>
    """
    # Use the numpy-interface operator when the numpy front-end is active.
    if is_np_array():
        resizer = _npi.cvimresize
    else:
        resizer = _internal._cvimresize
    return resizer(src, w, h, *args, **kwargs)
def imdecode(buf, *args, **kwargs):
    """Decode an image to an NDArray.

    .. note:: `imdecode` uses OpenCV (not the CV2 Python library).
       MXNet must have been built with USE_OPENCV=1 for `imdecode` to work.

    Parameters
    ----------
    buf : str/bytes/bytearray or numpy.ndarray
        Binary image data as string or numpy ndarray. Anything that is not
        already an `NDArray` is reinterpreted byte-wise as uint8.
    flag : int, optional, default=1
        1 for three channel color output. 0 for grayscale output.
    to_rgb : int, optional, default=1
        1 for RGB formatted output (MXNet default). 0 for BGR formatted output (OpenCV default).
    out : NDArray, optional
        Output buffer. Use `None` for automatic allocation.

    Returns
    -------
    NDArray
        An `NDArray` containing the image.

    Raises
    ------
    ValueError
        On Python 3, if `buf` is not an NDArray, bytes, bytearray or numpy.ndarray
        (e.g. a `str`).

    Example
    -------
    >>> with open("flower.jpg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image)
    >>> image
    <NDArray 224x224x3 @cpu(0)>

    Set `flag` parameter to 0 to get grayscale output

    >>> with open("flower.jpg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image, flag=0)
    >>> image
    <NDArray 224x224x1 @cpu(0)>

    Set `to_rgb` parameter to 0 to get output in OpenCV format (BGR)

    >>> with open("flower.jpg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image, to_rgb=0)
    >>> image
    <NDArray 224x224x3 @cpu(0)>
    """
    if not isinstance(buf, nd.NDArray):
        # On Python 3 a str has no well-defined byte representation here.
        if sys.version_info[0] == 3 and not isinstance(buf, (bytes, bytearray, np.ndarray)):
            # The two literals are joined, so the separating space must be explicit.
            raise ValueError('buf must be of type bytes, bytearray or numpy.ndarray, '
                             'if you would like to input type str, please convert to bytes')
        # Wrap the raw bytes into a uint8 array that the decoding operator accepts.
        array_fn = _mx_np.array if is_np_array() else nd.array
        buf = array_fn(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8)
    cvimdecode = _npi.cvimdecode if is_np_array() else _internal._cvimdecode
    return cvimdecode(buf, *args, **kwargs)
def scale_down(src_size, size):
    """Shrink a crop size so that it fits inside the image, keeping its aspect ratio.

    If the crop is taller than the image, it is scaled so that its height matches
    the image height. If it is still wider than the image, it is scaled again so
    that its width matches the image width.

    Parameters
    ----------
    src_size : tuple of int
        Size of the image in (width, height) format.
    size : tuple of int
        Size of the crop in (width, height) format.

    Returns
    -------
    tuple of int
        A tuple containing the scaled crop size in (width, height) format,
        truncated to integers.

    Example
    --------
    >>> src_size = (640,480)
    >>> size = (720,120)
    >>> new_size = mx.img.scale_down(src_size, size)
    >>> new_size
    (640,106)
    """
    crop_w, crop_h = size
    img_w, img_h = src_size
    if crop_h > img_h:
        # Fit the height first; the width shrinks by the same factor.
        crop_w = float(crop_w * img_h) / crop_h
        crop_h = img_h
    if crop_w > img_w:
        # Then fit the width; the height shrinks by the same factor.
        crop_h = float(crop_h * img_w) / crop_w
        crop_w = img_w
    return int(crop_w), int(crop_h)
def copyMakeBorder(src, top, bot, left, right, *args, **kwargs):
    """Pad image border with OpenCV.

    Parameters
    ----------
    src : NDArray
        source image
    top : int, required
        Top margin.
    bot : int, required
        Bottom margin.
    left : int, required
        Left margin.
    right : int, required
        Right margin.
    type : int, optional, default=0
        Filling type (default=cv2.BORDER_CONSTANT).
        0 - cv2.BORDER_CONSTANT - Adds a constant colored border.
        1 - cv2.BORDER_REFLECT - Border will be mirror reflection of the
        border elements, like this : fedcba|abcdefgh|hgfedcb
        2 - cv2.BORDER_REFLECT_101 or cv.BORDER_DEFAULT - Same as above,
        but with a slight change, like this : gfedcb|abcdefgh|gfedcba
        3 - cv2.BORDER_REPLICATE - Last element is replicated throughout,
        like this: aaaaaa|abcdefgh|hhhhhhh
        4 - cv2.BORDER_WRAP - it will look like this : cdefgh|abcdefgh|abcdefg
    value : double, optional, default=0
        (Deprecated! Use ``values`` instead.) Fill with single value.
    values : tuple of <double>, optional, default=[]
        Fill with value(RGB[A] or gray), up to 4 channels.
    out : NDArray, optional
        The output NDArray to hold the result.

    Returns
    -------
    out : NDArray or list of NDArrays
        The output of this function.

    Example
    --------
    >>> with open("flower.jpeg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image)
    >>> image
    <NDArray 2321x3482x3 @cpu(0)>
    >>> new_image = mx.image.copyMakeBorder(image, 1, 2, 3, 4, type=0)
    >>> new_image
    <NDArray 2324x3489x3 @cpu(0)>
    """
    return _internal._cvcopyMakeBorder(src, top, bot, left, right, *args, **kwargs)


def _get_interp_method(interp, sizes=()):
    """Get the interpolation method for resize functions.

    The major purpose of this function is to wrap a random interp method selection
    and an auto-estimation method.

    Parameters
    ----------
    interp : int
        interpolation method for all resizing operations
        Possible values:
        0: Nearest Neighbors Interpolation.
        1: Bilinear interpolation.
        2: Bicubic interpolation over 4x4 pixel neighborhood.
        3: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method.
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarge, area for shrink, bilinear for others
        10: Random select from interpolation method mentioned above.
        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).
        More details can be found in the documentation of OpenCV, please refer to
        http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.
    sizes : tuple of int
        (old_height, old_width, new_height, new_width). If empty (the default)
        or None, auto mode (9) falls back to Bicubic (2).

    Returns
    -------
    int
        interp method from 0 to 4

    Raises
    ------
    ValueError
        If `interp` is not one of 0-4, 9 or 10.
    """
    if interp == 9:
        if not sizes:
            return 2
        assert len(sizes) == 4
        oh, ow, nh, nw = sizes
        if nh > oh and nw > ow:
            return 2  # enlarging in both dimensions: bicubic
        if nh < oh and nw < ow:
            return 3  # shrinking in both dimensions: area-based
        return 1      # mixed or unchanged: bilinear
    if interp == 10:
        return random.randint(0, 4)
    if interp not in (0, 1, 2, 3, 4):
        # '%s' (not '%d') so that non-numeric input still yields ValueError.
        raise ValueError('Unknown interp method %s' % (interp,))
    return interp
def resize_short(src, size, interp=2):
    """Resizes shorter edge to size.

    .. note:: `resize_short` uses OpenCV (not the CV2 Python library).
       MXNet must have been built with OpenCV for `resize_short` to work.

    Resizes the original image by setting the shorter edge to size
    and scaling the longer edge by the same factor (integer floor division).
    Resizing function is called from OpenCV.

    Parameters
    ----------
    src : NDArray
        The original image, laid out as (height, width, channels).
    size : int
        The length to be set for the shorter edge.
    interp : int, optional, default=2
        Interpolation method used for resizing the image (default: Bicubic).
        Possible values:
        0: Nearest Neighbors Interpolation.
        1: Bilinear interpolation.
        2: Bicubic interpolation over 4x4 pixel neighborhood.
        3: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method.
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarge, area for shrink, bilinear for others
        10: Random select from interpolation method mentioned above.
        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).
        More details can be found in the documentation of OpenCV, please refer to
        http://docs.opencv.org/master/da/d54/group__imgproc__transform.html.

    Returns
    -------
    NDArray
        An 'NDArray' containing the resized image.

    Example
    -------
    >>> with open("flower.jpeg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.img.imdecode(str_image)
    >>> image
    <NDArray 2321x3482x3 @cpu(0)>
    >>> size = 640
    >>> new_image = mx.img.resize_short(image, size)
    >>> new_image
    <NDArray 640x960x3 @cpu(0)>
    """
    height, width, _ = src.shape
    if height > width:
        new_h, new_w = size * height // width, size
    else:
        new_h, new_w = size, size * width // height
    method = _get_interp_method(interp, (height, width, new_h, new_w))
    return imresize(src, new_w, new_h, interp=method)
def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
    """Crop `src` at a fixed location and optionally resize the crop.

    Parameters
    ----------
    src : NDArray
        Input image; the first axis is rows (height), the second columns (width).
    x0 : int
        Left boundary of the cropping area
    y0 : int
        Top boundary of the cropping area
    w : int
        Width of the cropping area
    h : int
        Height of the cropping area
    size : tuple of (w, h)
        Optional, resize to new size after cropping. No resize happens when
        `size` is None or equals ``(w, h)``.
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.

    Returns
    -------
    NDArray
        An `NDArray` containing the cropped image.
    """
    cropped = src[y0:y0 + h, x0:x0 + w]
    if size is None or (w, h) == size:
        return cropped
    # sizes are (old_h, old_w, new_h, new_w) for interpolation auto-selection.
    method = _get_interp_method(interp, (h, w, size[1], size[0]))
    return imresize(cropped, *size, interp=method)
def random_crop(src, size, interp=2):
    """Randomly crop `src` with `size` (width, height).
    Upsample result if `src` is smaller than `size`.

    Parameters
    ----------
    src: Source image `NDArray`, laid out as (height, width, channels).
    size: Size of the crop formatted as (width, height). If the `size` is larger
           than the image, then the source image is upsampled to `size` and returned.
    interp: int, optional, default=2
        Interpolation method. See resize_short for details.

    Returns
    -------
    NDArray
        An `NDArray` containing the cropped image.
    Tuple
        A tuple (x, y, width, height) where (x, y) is top-left position of the crop in the
        original image and (width, height) are the dimensions of the cropped image.

    Example
    -------
    >>> im = mx.nd.array(cv2.imread("flower.jpg"))
    >>> cropped_im, rect  = mx.image.random_crop(im, (100, 100))
    >>> print(cropped_im)
    <NDArray 100x100x3 @cpu(0)>
    >>> print(rect)
    (20, 21, 100, 100)
    """
    height, width, _ = src.shape
    # Shrink the requested crop (keeping its aspect ratio) so it fits in the image.
    crop_w, crop_h = scale_down((width, height), size)
    x0 = random.randint(0, width - crop_w)
    y0 = random.randint(0, height - crop_h)
    out = fixed_crop(src, x0, y0, crop_w, crop_h, size, interp)
    return out, (x0, y0, crop_w, crop_h)
def center_crop(src, size, interp=2):
    """Crop the image `src` to `size`, trimming all four sides around the center.

    Upsamples if `src` is smaller than `size`.

    .. note:: This requires MXNet to be compiled with USE_OPENCV.

    Parameters
    ----------
    src : NDArray
        Decoded source image, laid out as (height, width, channels).
    size : list or tuple of int
        The desired output image size as (width, height).
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.

    Returns
    -------
    NDArray
        The cropped image.
    Tuple
        (x, y, width, height) where x, y are the positions of the crop in the
        original image and width, height the dimensions of the crop.

    Example
    -------
    >>> with open("flower.jpg", 'rb') as fp:
    ...     str_image = fp.read()
    ...
    >>> image = mx.image.imdecode(str_image)
    >>> image
    <NDArray 2321x3482x3 @cpu(0)>
    >>> cropped_image, (x, y, width, height) = mx.image.center_crop(image, (1000, 500))
    >>> cropped_image
    <NDArray 500x1000x3 @cpu(0)>
    >>> x, y, width, height
    (1241, 910, 1000, 500)
    """
    img_h, img_w, _ = src.shape
    crop_w, crop_h = scale_down((img_w, img_h), size)
    # Split the leftover margin evenly (rounding toward zero) on each side.
    left = int((img_w - crop_w) / 2)
    top = int((img_h - crop_h) / 2)
    cropped = fixed_crop(src, left, top, crop_w, crop_h, size, interp)
    return cropped, (left, top, crop_w, crop_h)
def color_normalize(src, mean, std=None):
    """Normalize `src` by subtracting `mean` and dividing by `std`.

    Both operations are applied with in-place operators (``-=`` / ``/=``), so
    for mutable array types `src` itself is modified and returned.

    Parameters
    ----------
    src : NDArray
        Input image
    mean : NDArray
        RGB mean to be subtracted; skipped when None.
    std : NDArray
        RGB standard deviation to be divided; skipped when None.

    Returns
    -------
    NDArray
        An `NDArray` containing the normalized image.
    """
    result = src
    if mean is not None:
        result -= mean
    if std is not None:
        result /= std
    return result
def random_size_crop(src, size, area, ratio, interp=2, **kwargs):
    """Randomly crop `src` with a random area and aspect ratio, then resize to `size`.

    Up to 10 random (area, aspect ratio) candidates are drawn. The first one
    that fits inside the image is cropped at a random position. If none fits,
    the function falls back to `center_crop`.

    Parameters
    ----------
    src : NDArray
        Input image, laid out as (height, width, channels).
    size : tuple of (int, int)
        Size of the crop formatted as (width, height).
    area : float in (0, 1] or tuple of (float, float)
        If tuple, minimum area and maximum area to be maintained after cropping
        If float, minimum area to be maintained after cropping, maximum area is set to 1.0
    ratio : tuple of (float, float)
        Aspect ratio range as (min_aspect_ratio, max_aspect_ratio); sampled
        uniformly in log space.
    interp: int, optional, default=2
        Interpolation method. See resize_short for details.
    min_area : float, optional
        Deprecated alias that overrides `area` (emits a DeprecationWarning).

    Returns
    -------
    NDArray
        An `NDArray` containing the cropped image.
    Tuple
        A tuple (x, y, width, height) where (x, y) is top-left position of the crop in the
        original image and (width, height) are the dimensions of the cropped image.
    """
    img_h, img_w, _ = src.shape
    src_area = img_h * img_w

    if 'min_area' in kwargs:
        warnings.warn('`min_area` is deprecated. Please use `area` instead.',
                      DeprecationWarning)
        area = kwargs.pop('min_area')
    assert not kwargs, "unexpected keyword arguments for `random_size_crop`."

    if isinstance(area, numeric_types):
        area = (area, 1.0)

    # The aspect-ratio bounds do not change between attempts; compute them once.
    log_lo, log_hi = np.log(ratio[0]), np.log(ratio[1])
    for _ in range(10):
        target_area = random.uniform(area[0], area[1]) * src_area
        aspect = np.exp(random.uniform(log_lo, log_hi))
        crop_w = int(round(np.sqrt(target_area * aspect)))
        crop_h = int(round(np.sqrt(target_area / aspect)))
        if crop_w > img_w or crop_h > img_h:
            continue
        x0 = random.randint(0, img_w - crop_w)
        y0 = random.randint(0, img_h - crop_h)
        cropped = fixed_crop(src, x0, y0, crop_w, crop_h, size, interp)
        return cropped, (x0, y0, crop_w, crop_h)

    # fall back to center_crop
    return center_crop(src, size, interp)
class Augmenter(object):
    """Image Augmenter base class.

    Constructor keyword arguments are stored in ``self._kwargs`` (NDArray and
    numpy values converted to plain lists) so that `dumps` can serialize them.
    """

    def __init__(self, **kwargs):
        def _plain(value):
            # NDArray -> numpy -> nested list, so json.dumps can handle it.
            if isinstance(value, nd.NDArray):
                value = value.asnumpy()
            if isinstance(value, np.ndarray):
                value = value.tolist()
            return value

        self._kwargs = {key: _plain(value) for key, value in kwargs.items()}

    def dumps(self):
        """Saves the Augmenter to string

        Returns
        -------
        str
            JSON formatted string ``[lowercase_class_name, kwargs]`` that
            describes the Augmenter.
        """
        payload = [self.__class__.__name__.lower(), self._kwargs]
        return json.dumps(payload)

    def __call__(self, src):
        """Abstract implementation body; subclasses must override."""
        raise NotImplementedError("Must override implementation.")
class SequentialAug(Augmenter):
    """Composing a sequential augmenter list.

    Parameters
    ----------
    ts : list of augmenters
        A series of augmenters to be applied in sequential order.
    """

    def __init__(self, ts):
        super(SequentialAug, self).__init__()
        self.ts = ts

    def dumps(self):
        """Override the default to avoid duplicate dump.

        Returns
        -------
        list
            ``[lowercase_class_name, [child.dumps() for each child]]``.
        """
        child_dumps = [aug.dumps() for aug in self.ts]
        return [self.__class__.__name__.lower(), child_dumps]

    def __call__(self, src):
        """Augmenter body: feed `src` through every child augmenter in order."""
        result = src
        for aug in self.ts:
            result = aug(result)
        return result
class ResizeAug(Augmenter):
    """Make resize shorter edge to size augmenter.

    Parameters
    ----------
    size : int
        The length to be set for the shorter edge.
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.
    """

    def __init__(self, size, interp=2):
        super(ResizeAug, self).__init__(size=size, interp=interp)
        self.size = size
        self.interp = interp

    def __call__(self, src):
        """Augmenter body: delegate to `resize_short`."""
        shorter_edge, method = self.size, self.interp
        return resize_short(src, shorter_edge, method)
class ForceResizeAug(Augmenter):
    """Force resize to size regardless of aspect ratio

    Parameters
    ----------
    size : tuple of (int, int)
        The desired size as in (width, height)
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.
    """

    def __init__(self, size, interp=2):
        super(ForceResizeAug, self).__init__(size=size, interp=interp)
        self.size = size
        self.interp = interp

    def __call__(self, src):
        """Augmenter body: resize `src` to exactly ``self.size``."""
        old_h, old_w = src.shape[0], src.shape[1]
        # (old_h, old_w, new_h, new_w) drives the auto interpolation choice.
        method = _get_interp_method(self.interp, (old_h, old_w, self.size[1], self.size[0]))
        return imresize(src, *self.size, interp=method)
class RandomCropAug(Augmenter):
    """Make random crop augmenter

    Parameters
    ----------
    size : tuple of (int, int)
        Size of the crop formatted as (width, height). If it is larger than the
        image, the image is upsampled to this size (see `random_crop`).
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.
    """

    def __init__(self, size, interp=2):
        super(RandomCropAug, self).__init__(size=size, interp=interp)
        self.size = size
        self.interp = interp

    def __call__(self, src):
        """Augmenter body: return only the image from `random_crop`, not the rectangle."""
        cropped, _ = random_crop(src, self.size, self.interp)
        return cropped
class RandomSizedCropAug(Augmenter):
    """Make random crop with random resizing and random aspect ratio jitter augmenter.

    Parameters
    ----------
    size : tuple of (int, int)
        Size of the crop formatted as (width, height).
    area : float in (0, 1] or tuple of (float, float)
        If tuple, minimum area and maximum area to be maintained after cropping
        If float, minimum area to be maintained after cropping, maximum area is set to 1.0
    ratio : tuple of (float, float)
        Aspect ratio range as (min_aspect_ratio, max_aspect_ratio)
    interp: int, optional, default=2
        Interpolation method. See resize_short for details.
    min_area : float, optional
        Deprecated alias that overrides `area` (emits a DeprecationWarning).
    """

    def __init__(self, size, area, ratio, interp=2, **kwargs):
        if 'min_area' in kwargs:
            warnings.warn('`min_area` is deprecated. Please use `area` instead.',
                          DeprecationWarning)
            area = kwargs.pop('min_area')
        assert not kwargs, "unexpected keyword arguments for `RandomSizedCropAug`."
        # Register the *effective* area so that dumps() describes the crop that
        # is actually performed, even when the deprecated `min_area` was used.
        super(RandomSizedCropAug, self).__init__(size=size, area=area,
                                                 ratio=ratio, interp=interp)
        self.size = size
        self.area = area
        self.ratio = ratio
        self.interp = interp

    def __call__(self, src):
        """Augmenter body: return only the image from `random_size_crop`."""
        return random_size_crop(src, self.size, self.area, self.ratio, self.interp)[0]
class CenterCropAug(Augmenter):
    """Make center crop augmenter.

    Parameters
    ----------
    size : list or tuple of int
        The desired output image size as (width, height).
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.
    """

    def __init__(self, size, interp=2):
        super(CenterCropAug, self).__init__(size=size, interp=interp)
        self.size = size
        self.interp = interp

    def __call__(self, src):
        """Augmenter body: return only the image from `center_crop`."""
        cropped, _ = center_crop(src, self.size, self.interp)
        return cropped
class RandomOrderAug(Augmenter):
    """Apply list of augmenters in random order

    Parameters
    ----------
    ts : list of augmenters
        A series of augmenters to be applied in random order. The list itself
        is not reordered by this augmenter.
    """

    def __init__(self, ts):
        super(RandomOrderAug, self).__init__()
        self.ts = ts

    def dumps(self):
        """Override the default to avoid duplicate dump.

        Returns
        -------
        list
            ``[lowercase_class_name, [child.dumps() for each child]]``.
        """
        return [self.__class__.__name__.lower(), [x.dumps() for x in self.ts]]

    def __call__(self, src):
        """Augmenter body: apply every child once, in a freshly shuffled order."""
        # Shuffle a copy: shuffling self.ts in place would mutate the caller's
        # list and make dumps() output change from call to call.
        order = list(self.ts)
        random.shuffle(order)
        for t in order:
            src = t(src)
        return src
class BrightnessJitterAug(Augmenter):
    """Random brightness jitter augmentation.

    Parameters
    ----------
    brightness : float
        The brightness jitter ratio range, [0, 1]
    """

    def __init__(self, brightness):
        super(BrightnessJitterAug, self).__init__(brightness=brightness)
        self.brightness = brightness

    def __call__(self, src):
        """Augmenter body: scale `src` in place by a factor in [1 - b, 1 + b]."""
        offset = random.uniform(-self.brightness, self.brightness)
        src *= 1.0 + offset
        return src
class ContrastJitterAug(Augmenter):
    """Random contrast jitter augmentation.

    Parameters
    ----------
    contrast : float
        The contrast jitter ratio range, [0, 1]
    """

    def __init__(self, contrast):
        super(ContrastJitterAug, self).__init__(contrast=contrast)
        self.contrast = contrast
        # RGB-to-gray weights, broadcast over the last axis.
        # NOTE(review): assumes src is (H, W, 3) in RGB order.
        self.coef = nd.array([[[0.299, 0.587, 0.114]]])

    def __call__(self, src):
        """Augmenter body: blend `src` in place with its mean gray level."""
        alpha = 1.0 + random.uniform(-self.contrast, self.contrast)
        weighted = src * self.coef
        # 3 * sum / size is the image-wide mean gray value; scale it by (1 - alpha).
        mean_gray = (3.0 * (1.0 - alpha) / weighted.size) * nd.sum(weighted)
        src *= alpha
        src += mean_gray
        return src
class SaturationJitterAug(Augmenter):
    """Random saturation jitter augmentation.

    Parameters
    ----------
    saturation : float
        The saturation jitter ratio range, [0, 1]
    """

    def __init__(self, saturation):
        super(SaturationJitterAug, self).__init__(saturation=saturation)
        self.saturation = saturation
        # RGB-to-gray weights, broadcast over the last axis.
        # NOTE(review): assumes src is (H, W, 3) in RGB order.
        self.coef = nd.array([[[0.299, 0.587, 0.114]]])

    def __call__(self, src):
        """Augmenter body: blend each pixel in place with its own gray value."""
        alpha = 1.0 + random.uniform(-self.saturation, self.saturation)
        # Per-pixel gray level, kept as a trailing axis of size 1 for broadcasting.
        gray = nd.sum(src * self.coef, axis=2, keepdims=True)
        gray *= (1.0 - alpha)
        src *= alpha
        src += gray
        return src
class HueJitterAug(Augmenter):
    """Random hue jitter augmentation.

    Parameters
    ----------
    hue : float
        The hue jitter ratio range, [0, 1]
    """

    def __init__(self, hue):
        super(HueJitterAug, self).__init__(hue=hue)
        self.hue = hue
        # RGB -> YIQ and YIQ -> RGB conversion matrices.
        self.tyiq = np.array([[0.299, 0.587, 0.114],
                              [0.596, -0.274, -0.321],
                              [0.211, -0.523, 0.311]])
        self.ityiq = np.array([[1.0, 0.956, 0.621],
                               [1.0, -0.272, -0.647],
                               [1.0, -1.107, 1.705]])

    def __call__(self, src):
        """Augmenter body.

        Rotates the I/Q chroma plane by a random angle, using the approximate
        linear transformation described in:
        https://beesbuzz.biz/code/hsv_color_transforms.php
        """
        theta = random.uniform(-self.hue, self.hue) * np.pi
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rotation = np.array([[1.0, 0.0, 0.0],
                             [0.0, cos_t, -sin_t],
                             [0.0, sin_t, cos_t]])
        # RGB -> YIQ, rotate chroma, YIQ -> RGB; transposed for src @ T form.
        transform = np.dot(np.dot(self.ityiq, rotation), self.tyiq).T
        return nd.dot(src, nd.array(transform))
class ColorJitterAug(RandomOrderAug):
    """Apply random brightness, contrast and saturation jitter in random order.

    Only the jitters with a strictly positive range are included.

    Parameters
    ----------
    brightness : float
        The brightness jitter ratio range, [0, 1]
    contrast : float
        The contrast jitter ratio range, [0, 1]
    saturation : float
        The saturation jitter ratio range, [0, 1]
    """

    def __init__(self, brightness, contrast, saturation):
        candidates = (
            (brightness, BrightnessJitterAug),
            (contrast, ContrastJitterAug),
            (saturation, SaturationJitterAug),
        )
        ts = [aug_cls(amount) for amount, aug_cls in candidates if amount > 0]
        super(ColorJitterAug, self).__init__(ts)
class LightingAug(Augmenter):
    """Add PCA based noise.

    Parameters
    ----------
    alphastd : float
        Noise level (standard deviation of the per-component weights).
    eigval : 3x1 np.array
        Eigen values
    eigvec : 3x3 np.array
        Eigen vectors
    """

    def __init__(self, alphastd, eigval, eigvec):
        super(LightingAug, self).__init__(alphastd=alphastd, eigval=eigval, eigvec=eigvec)
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, src):
        """Augmenter body: add one random 3-vector offset to every pixel, in place."""
        weights = np.random.normal(0, self.alphastd, size=(3,))
        offset = np.dot(self.eigvec * weights, self.eigval)
        src += nd.array(offset)
        return src
class ColorNormalizeAug(Augmenter):
    """Mean and std normalization.

    Parameters
    ----------
    mean : NDArray
        RGB mean to be subtracted; non-NDArray values are converted with nd.array.
    std : NDArray
        RGB standard deviation to be divided; non-NDArray values are converted.
    """

    def __init__(self, mean, std):
        super(ColorNormalizeAug, self).__init__(mean=mean, std=std)

        def _as_ndarray(value):
            # Leave None and NDArray untouched; wrap anything else.
            if value is None or isinstance(value, nd.NDArray):
                return value
            return nd.array(value)

        self.mean = _as_ndarray(mean)
        self.std = _as_ndarray(std)

    def __call__(self, src):
        """Augmenter body: delegate to `color_normalize`."""
        return color_normalize(src, self.mean, self.std)
class RandomGrayAug(Augmenter):
    """Randomly convert to gray image.

    The channel count is preserved: each output channel receives the same
    weighted sum 0.21 * c0 + 0.72 * c1 + 0.07 * c2.

    Parameters
    ----------
    p : float
        Probability to convert to grayscale
    """

    def __init__(self, p):
        super(RandomGrayAug, self).__init__(p=p)
        self.p = p
        self.mat = nd.array([[0.21, 0.21, 0.21],
                             [0.72, 0.72, 0.72],
                             [0.07, 0.07, 0.07]])

    def __call__(self, src):
        """Augmenter body: with probability `p`, project `src` onto gray."""
        if random.random() < self.p:
            return nd.dot(src, self.mat)
        return src
class HorizontalFlipAug(Augmenter):
    """Random horizontal flip.

    Parameters
    ----------
    p : float
        Probability to flip image horizontally
    """

    def __init__(self, p):
        super(HorizontalFlipAug, self).__init__(p=p)
        self.p = p

    def __call__(self, src):
        """Augmenter body: with probability `p`, reverse `src` along axis 1."""
        if random.random() < self.p:
            return nd.flip(src, axis=1)
        return src
class CastAug(Augmenter):
    """Cast the image to `typ` (float32 by default).

    Parameters
    ----------
    typ : str, optional, default='float32'
        Target data type; recorded under the ``type`` key for `dumps`.
    """

    def __init__(self, typ='float32'):
        super(CastAug, self).__init__(type=typ)
        self.typ = typ

    def __call__(self, src):
        """Augmenter body: return ``src.astype(self.typ)``."""
        return src.astype(self.typ)
def CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, rand_mirror=False,
                    mean=None, std=None, brightness=0, contrast=0, saturation=0, hue=0,
                    pca_noise=0, rand_gray=0, inter_method=2):
    """Creates an augmenter list.

    Parameters
    ----------
    data_shape : tuple of int
        Shape for output data, in (channels, height, width) format.
    resize : int
        Resize shorter edge if larger than 0 at the beginning
    rand_crop : bool
        Whether to enable random cropping other than center crop
    rand_resize : bool
        Whether to enable random sized cropping, require rand_crop to be enabled
        (an AssertionError is raised otherwise)
    rand_gray : float
        [0, 1], probability to convert to grayscale for all channels, the number
        of channels will not be reduced to 1
    rand_mirror : bool
        Whether to apply horizontal flip to image with probability 0.5
    mean : np.ndarray, NDArray, True or None
        Mean pixel values for [r, g, b]. ``True`` selects the built-in defaults
        [123.68, 116.28, 103.53].
    std : np.ndarray, NDArray, True or None
        Standard deviations for [r, g, b]. ``True`` selects the built-in defaults
        [58.395, 57.12, 57.375].
    brightness : float
        Brightness jittering range (percent)
    contrast : float
        Contrast jittering range (percent)
    saturation : float
        Saturation jittering range (percent)
    hue : float
        Hue jittering range (percent)
    pca_noise : float
        Pca noise level (percent)
    inter_method : int, default=2 (Bicubic)
        Interpolation method for all resizing operations
        Possible values:
        0: Nearest Neighbors Interpolation.
        1: Bilinear interpolation.
        2: Bicubic interpolation over 4x4 pixel neighborhood.
        3: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method.
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarge, area for shrink, bilinear for others
        10: Random select from interpolation method mentioned above.
        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).

    Returns
    -------
    list of Augmenter
        Augmenters in application order: optional resize, crop, optional
        mirror, cast to float32, optional color/hue/PCA/gray jitter, and
        optional mean/std normalization.

    Examples
    --------
    >>> # An example of creating multiple augmenters
    >>> augs = mx.image.CreateAugmenter(data_shape=(3, 300, 300), rand_mirror=True,
    ...    mean=True, brightness=0.125, contrast=0.125, rand_gray=0.05,
    ...    saturation=0.125, pca_noise=0.05, inter_method=10)
    >>> # dump the details
    >>> for aug in augs:
    ...    aug.dumps()
    """
    auglist = []

    if resize > 0:
        auglist.append(ResizeAug(resize, inter_method))

    # data_shape is (channels, height, width); crop sizes are (width, height).
    crop_size = (data_shape[2], data_shape[1])
    if rand_resize:
        assert rand_crop
        auglist.append(RandomSizedCropAug(crop_size, 0.08, (3.0 / 4.0, 4.0 / 3.0), inter_method))
    elif rand_crop:
        auglist.append(RandomCropAug(crop_size, inter_method))
    else:
        auglist.append(CenterCropAug(crop_size, inter_method))

    if rand_mirror:
        auglist.append(HorizontalFlipAug(0.5))

    # Color augmenters below operate on floating point data.
    auglist.append(CastAug())

    if brightness or contrast or saturation:
        auglist.append(ColorJitterAug(brightness, contrast, saturation))

    if hue:
        auglist.append(HueJitterAug(hue))

    if pca_noise > 0:
        # NOTE(review): presumably the RGB PCA statistics of ImageNet
        # (AlexNet-style lighting noise) — confirm the source of these constants.
        eigval = np.array([55.46, 4.794, 1.148])
        eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                           [-0.5808, -0.0045, -0.8140],
                           [-0.5836, -0.6948, 0.4203]])
        auglist.append(LightingAug(pca_noise, eigval, eigvec))

    if rand_gray > 0:
        auglist.append(RandomGrayAug(rand_gray))

    if mean is True:
        mean = nd.array([123.68, 116.28, 103.53])
    elif mean is not None:
        assert isinstance(mean, (np.ndarray, nd.NDArray)) and mean.shape[0] in [1, 3]

    if std is True:
        std = nd.array([58.395, 57.12, 57.375])
    elif std is not None:
        assert isinstance(std, (np.ndarray, nd.NDArray)) and std.shape[0] in [1, 3]

    if mean is not None or std is not None:
        auglist.append(ColorNormalizeAug(mean, std))

    return auglist
class ImageIter(io.DataIter):
    """Image data iterator with a large number of augmentation choices.
    This iterator supports reading from both .rec files and raw image files.

    To load input images from .rec files, use `path_imgrec` parameter and to load from raw image
    files, use `path_imglist` and `path_root` parameters.

    To use data partition (for distributed training) or shuffling, specify `path_imgidx` parameter.

    Parameters
    ----------
    batch_size : int
        Number of examples per batch.
    data_shape : tuple
        Data shape in (channels, height, width) format.
        For now, only RGB image with 3 channels is supported.
        A list is accepted as well and is converted to a tuple.
    label_width : int, optional
        Number of labels per example. The default label width is 1.
    path_imgrec : str
        Path to image record file (.rec).
        Created with tools/im2rec.py or bin/im2rec.
    path_imglist : str
        Path to image list (.lst).
        Created with tools/im2rec.py or with custom script.
        Format: Tab separated record of index, one or more labels and relative_path_from_root.
        Blank lines are ignored.
    imglist: list
        A list of images with the label(s).
        Each item is a list [imagelabel: float or list of float, imgpath].
    path_root : str
        Root folder of image files.
    path_imgidx : str
        Path to image index file. Needed for partition and shuffling when using .rec source.
    shuffle : bool
        Whether to shuffle all images at the start of each iteration or not.
        Can be slow for HDD.
    part_index : int
        Partition index.
    num_parts : int
        Total number of partitions. Each partition receives ``len(keys) // num_parts``
        keys; the remaining ``len(keys) % num_parts`` keys are not used.
    aug_list : list of callables, optional
        Augmenters applied in order to every decoded image. When None, the list is
        built with ``CreateAugmenter(data_shape, **kwargs)``.
    data_name : str
        Data name for provided symbols.
    label_name : str
        Label name for provided symbols.
    dtype : str
        Label data type. Default: float32. Other options: int32, int64, float64
    last_batch_handle : str, optional
        How to handle the last batch.
        This parameter can be 'pad'(default), 'discard' or 'roll_over'.
        If 'pad', the last batch will be padded with data starting from the beginning
        If 'discard', the last batch will be discarded
        If 'roll_over', the remaining elements will be rolled over to the next iteration
        Any other value raises ValueError.
    kwargs : ...
        More arguments for creating augmenter. See mx.image.CreateAugmenter.
    """

    def __init__(self, batch_size, data_shape, label_width=1,
                 path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
                 shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 last_batch_handle='pad', **kwargs):
        super(ImageIter, self).__init__()
        assert path_imgrec or path_imglist or (isinstance(imglist, list))
        assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported'
        # next() only understands these three modes; anything else would silently
        # fall into the roll_over cache-clearing branch.
        if last_batch_handle not in ('pad', 'discard', 'roll_over'):
            raise ValueError("last_batch_handle must be 'pad', 'discard' or 'roll_over', "
                             "got {!r}".format(last_batch_handle))
        num_threads = os.environ.get('MXNET_CPU_WORKER_NTHREADS', 1)
        logging.info('Using %s threads for decoding...', str(num_threads))
        logging.info('Set enviroment variable MXNET_CPU_WORKER_NTHREADS to a'
                     ' larger number to use more threads.')
        class_name = self.__class__.__name__
        if path_imgrec:
            logging.info('%s: loading recordio %s...',
                         class_name, path_imgrec)
            if path_imgidx:
                self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
                self.imgidx = list(self.imgrec.keys)
            else:
                self.imgrec = recordio.MXRecordIO(path_imgrec, 'r')
                self.imgidx = None
        else:
            self.imgrec = None

        array_fn = _mx_np.array if is_np_array() else nd.array
        if path_imglist:
            logging.info('%s: loading image list %s...', class_name, path_imglist)
            with open(path_imglist) as fin:
                imglist = {}
                imgkeys = []
                for line in fin:
                    line = line.strip()
                    if not line:
                        # tolerate blank lines (e.g. a trailing empty line at EOF),
                        # which would otherwise fail in int('')
                        continue
                    fields = line.split('\t')
                    label = array_fn(fields[1:-1], dtype=dtype)
                    key = int(fields[0])
                    imglist[key] = (label, fields[-1])
                    imgkeys.append(key)
                self.imglist = imglist
        elif isinstance(imglist, list):
            logging.info('%s: loading image list...', class_name)
            result = {}
            imgkeys = []
            # keys are 1-based string indices into the user-supplied list
            for index, img in enumerate(imglist, start=1):
                key = str(index)
                if len(img) > 2:
                    label = array_fn(img[:-1], dtype=dtype)
                elif isinstance(img[0], numeric_types):
                    label = array_fn([img[0]], dtype=dtype)
                else:
                    label = array_fn(img[0], dtype=dtype)
                result[key] = (label, img[-1])
                imgkeys.append(key)
            self.imglist = result
        else:
            self.imglist = None
        self.path_root = path_root

        self.check_data_shape(data_shape)
        # accept lists as well; tuple concatenation below requires a tuple
        data_shape = tuple(data_shape)
        self.provide_data = [(data_name, (batch_size,) + data_shape)]
        if label_width > 1:
            self.provide_label = [(label_name, (batch_size, label_width))]
        else:
            self.provide_label = [(label_name, (batch_size,))]
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.label_width = label_width
        self.shuffle = shuffle
        if self.imgrec is None:
            self.seq = imgkeys
        elif shuffle or num_parts > 1 or path_imgidx:
            assert self.imgidx is not None
            self.seq = self.imgidx
        else:
            # plain sequential read from the .rec file; no key sequence
            self.seq = None

        if num_parts > 1:
            assert part_index < num_parts
            N = len(self.seq)
            C = N // num_parts
            # equal-sized partitions; the trailing N % num_parts keys are dropped
            self.seq = self.seq[part_index * C:(part_index + 1) * C]
        if aug_list is None:
            self.auglist = CreateAugmenter(data_shape, **kwargs)
        else:
            self.auglist = aug_list
        self.cur = 0
        self._allow_read = True
        self.last_batch_handle = last_batch_handle
        self.num_image = len(self.seq) if self.seq is not None else None
        self._cache_data = None
        self._cache_label = None
        self._cache_idx = None
        self.reset()
[docs] def reset(self): """Resets the iterator to the beginning of the data.""" if self.seq is not None and self.shuffle: random.shuffle(self.seq) if self.last_batch_handle != 'roll_over' or \ self._cache_data is None: if self.imgrec is not None: self.imgrec.reset() self.cur = 0 if self._allow_read is False: self._allow_read = True
[docs] def hard_reset(self): """Resets the iterator and ignore roll over data""" if self.seq is not None and self.shuffle: random.shuffle(self.seq) if self.imgrec is not None: self.imgrec.reset() self.cur = 0 self._allow_read = True self._cache_data = None self._cache_label = None self._cache_idx = None
[docs] def next_sample(self): """Helper function for reading in next sample.""" if self._allow_read is False: raise StopIteration if self.seq is not None: if self.cur < self.num_image: idx = self.seq[self.cur] else: if self.last_batch_handle != 'discard': self.cur = 0 raise StopIteration self.cur += 1 if self.imgrec is not None: s = self.imgrec.read_idx(idx) header, img = recordio.unpack(s) if self.imglist is None: return header.label, img else: return self.imglist[idx][0], img else: label, fname = self.imglist[idx] return label, self.read_image(fname) else: s = self.imgrec.read() if s is None: if self.last_batch_handle != 'discard': self.imgrec.reset() raise StopIteration header, img = recordio.unpack(s) return header.label, img
def _batchify(self, batch_data, batch_label, start=0): """Helper function for batchifying data""" i = start batch_size = self.batch_size try: while i < batch_size: label, s = self.next_sample() data = self.imdecode(s) try: self.check_valid_image(data) except RuntimeError as e: logging.debug('Invalid image, skipping: %s', str(e)) continue data = self.augmentation_transform(data) assert i < batch_size, 'Batch size must be multiples of augmenter output length' batch_data[i] = self.postprocess_data(data) batch_label[i] = label i += 1 except StopIteration: if not i: raise StopIteration return i
    def next(self):
        """Returns the next batch of data.

        Fills a ``(batch_size, c, h, w)`` data buffer and a label buffer of shape
        ``self.provide_label[0][1]`` via ``_batchify``. When fewer than
        ``batch_size`` samples could be read, the outcome depends on
        ``last_batch_handle``:

        * ``'discard'``: raise StopIteration, dropping the partial batch.
        * ``'roll_over'``: on the first short batch, store the partially filled
          buffers and fill count in ``_cache_*`` and raise StopIteration. The
          next call resumes filling those cached buffers and then clears the cache.
        * ``'pad'``: top up the remaining slots with further samples (read from
          the position ``next_sample`` rewound to), then disable reading until
          ``reset``/``hard_reset``.

        The returned batch's ``pad`` equals ``batch_size`` minus the number of
        slots filled before the top-up step.

        NOTE(review): for a resumed 'roll_over' batch, the top-up fills the
        remaining slots with new-epoch samples, yet ``pad`` still reports
        ``batch_size - _cache_idx``. Confirm this is the intended semantics for
        consumers that drop the last ``pad`` samples.

        Raises
        ------
        StopIteration
            When no sample is left, or when a short batch is discarded or cached.
        """
        batch_size = self.batch_size
        c, h, w = self.data_shape
        # if last batch data is rolled over
        if self._cache_data is not None:
            # check both the data and label have values
            assert self._cache_label is not None, "_cache_label didn't have values"
            assert self._cache_idx is not None, "_cache_idx didn't have values"
            # resume filling the partial batch cached at the end of the last epoch
            batch_data = self._cache_data
            batch_label = self._cache_label
            i = self._cache_idx
            # the cache itself is cleared below, after the batch is topped up
        else:
            # allocate buffers with the active array API (mx.np vs legacy mx.nd)
            if is_np_array():
                zeros_fn = _mx_np.zeros
                empty_fn = _mx_np.empty
            else:
                zeros_fn = nd.zeros
                empty_fn = nd.empty
            batch_data = zeros_fn((batch_size, c, h, w))
            batch_label = empty_fn(self.provide_label[0][1])
            i = self._batchify(batch_data, batch_label)
        # calculate the padding
        pad = batch_size - i
        # handle padding for the last batch
        if pad != 0:
            if self.last_batch_handle == 'discard':
                raise StopIteration
            # if the option is 'roll_over', throw StopIteration and cache the data
            if self.last_batch_handle == 'roll_over' and \
                self._cache_data is None:
                self._cache_data = batch_data
                self._cache_label = batch_label
                self._cache_idx = i
                raise StopIteration

            # top up the remaining slots; the new fill count does not affect `pad`
            _ = self._batchify(batch_data, batch_label, i)
            if self.last_batch_handle == 'pad':
                # the padded batch is the last one until the iterator is reset
                self._allow_read = False
            else:
                # 'roll_over': the cached batch has now been consumed
                self._cache_data = None
                self._cache_label = None
                self._cache_idx = None

        return io.DataBatch([batch_data], [batch_label], pad=pad)
[docs] def check_data_shape(self, data_shape): """Checks if the input data shape is valid""" if not len(data_shape) == 3: raise ValueError('data_shape should have length 3, with dimensions CxHxW') if not data_shape[0] == 3: raise ValueError('This iterator expects inputs to have 3 channels.')
[docs] def check_valid_image(self, data): """Checks if the input data is valid""" if len(data[0].shape) == 0: raise RuntimeError('Data shape is wrong')
[docs] def imdecode(self, s): """Decodes a string or byte string to an NDArray. See mx.img.imdecode for more details.""" def locate(): """Locate the image file/index if decode fails.""" if self.seq is not None: idx = self.seq[(self.cur % self.num_image) - 1] else: idx = (self.cur % self.num_image) - 1 if self.imglist is not None: _, fname = self.imglist[idx] msg = "filename: {}".format(fname) else: msg = "index: {}".format(idx) return "Broken image " + msg try: img = imdecode(s) except Exception as e: raise RuntimeError("{}, {}".format(locate(), e)) return img
[docs] def read_image(self, fname): """Reads an input image `fname` and returns the decoded raw bytes. Examples -------- >>> dataIter.read_image('Face.jpg') # returns decoded raw bytes. """ with open(os.path.join(self.path_root, fname), 'rb') as fin: img = fin.read() return img
[docs] def augmentation_transform(self, data): """Transforms input data with specified augmentation.""" for aug in self.auglist: data = aug(data) return data
[docs] def postprocess_data(self, datum): """Final postprocessing step before image is loaded into the batch.""" if is_np_array(): return datum.transpose(2, 0, 1) else: return nd.transpose(datum, axes=(2, 0, 1))