# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""Read images and perform augmentations for object detection."""
from __future__ import absolute_import, print_function
import json
import logging
import random
import warnings
import numpy as np
from ..base import numeric_types
from .. import ndarray as nd
from ..ndarray._internal import _cvcopyMakeBorder as copyMakeBorder
from .. import io
from .image import RandomOrderAug, ColorJitterAug, LightingAug, ColorNormalizeAug
from .image import ResizeAug, ForceResizeAug, CastAug, HueJitterAug, RandomGrayAug
from .image import fixed_crop, ImageIter, Augmenter
[docs]class DetAugmenter(object):
"""Detection base augmenter"""
def __init__(self, **kwargs):
self._kwargs = kwargs
for k, v in self._kwargs.items():
if isinstance(v, nd.NDArray):
v = v.asnumpy()
if isinstance(v, np.ndarray):
v = v.tolist()
self._kwargs[k] = v
[docs] def dumps(self):
"""Saves the Augmenter to string
Returns
-------
str
JSON formatted string that describes the Augmenter.
"""
return json.dumps([self.__class__.__name__.lower(), self._kwargs])
def __call__(self, src, label):
"""Abstract implementation body"""
raise NotImplementedError("Must override implementation.")
[docs]class DetBorrowAug(DetAugmenter):
"""Borrow standard augmenter from image classification.
Which is good once you know label won't be affected after this augmenter.
Parameters
----------
augmenter : mx.image.Augmenter
The borrowed standard augmenter which has no effect on label
"""
def __init__(self, augmenter):
if not isinstance(augmenter, Augmenter):
raise TypeError('Borrowing from invalid Augmenter')
super(DetBorrowAug, self).__init__(augmenter=augmenter.dumps())
self.augmenter = augmenter
def dumps(self):
"""Override the default one to avoid duplicate dump."""
return [self.__class__.__name__.lower(), self.augmenter.dumps()]
def __call__(self, src, label):
"""Augmenter implementation body"""
src = self.augmenter(src)
return (src, label)
[docs]class DetRandomSelectAug(DetAugmenter):
"""Randomly select one augmenter to apply, with chance to skip all.
Parameters
----------
aug_list : list of DetAugmenter
The random selection will be applied to one of the augmenters
skip_prob : float
The probability to skip all augmenters and return input directly
"""
def __init__(self, aug_list, skip_prob=0):
super(DetRandomSelectAug, self).__init__(skip_prob=skip_prob)
if not isinstance(aug_list, (list, tuple)):
aug_list = [aug_list]
for aug in aug_list:
if not isinstance(aug, DetAugmenter):
raise ValueError('Allow DetAugmenter in list only')
if not aug_list:
skip_prob = 1 # disabled
self.aug_list = aug_list
self.skip_prob = skip_prob
def dumps(self):
"""Override default."""
return [self.__class__.__name__.lower(), [x.dumps() for x in self.aug_list]]
def __call__(self, src, label):
"""Augmenter implementation body"""
if random.random() < self.skip_prob:
return (src, label)
else:
random.shuffle(self.aug_list)
return self.aug_list[0](src, label)
[docs]class DetHorizontalFlipAug(DetAugmenter):
"""Random horizontal flipping.
Parameters
----------
p : float
chance [0, 1] to flip
"""
def __init__(self, p):
super(DetHorizontalFlipAug, self).__init__(p=p)
self.p = p
def __call__(self, src, label):
"""Augmenter implementation"""
if random.random() < self.p:
src = nd.flip(src, axis=1)
self._flip_label(label)
return (src, label)
def _flip_label(self, label):
"""Helper function to flip label."""
tmp = 1.0 - label[:, 1]
label[:, 1] = 1.0 - label[:, 3]
label[:, 3] = tmp
[docs]class DetRandomCropAug(DetAugmenter):
"""Random cropping with constraints
Parameters
----------
min_object_covered : float, default=0.1
The cropped area of the image must contain at least this fraction of
any bounding box supplied. The value of this parameter should be non-negative.
In the case of 0, the cropped area does not need to overlap any of the
bounding boxes supplied.
min_eject_coverage : float, default=0.3
The minimum coverage of cropped sample w.r.t its original size. With this
constraint, objects that have marginal area after crop will be discarded.
aspect_ratio_range : tuple of floats, default=(0.75, 1.33)
The cropped area of the image must have an aspect ratio = width / height
within this range.
area_range : tuple of floats, default=(0.05, 1.0)
The cropped area of the image must contain a fraction of the supplied
image within in this range.
max_attempts : int, default=50
Number of attempts at generating a cropped/padded region of the image of the
specified constraints. After max_attempts failures, return the original image.
"""
def __init__(self, min_object_covered=0.1, aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0), min_eject_coverage=0.3, max_attempts=50):
if not isinstance(aspect_ratio_range, (tuple, list)):
assert isinstance(aspect_ratio_range, numeric_types)
logging.info('Using fixed aspect ratio: %s in DetRandomCropAug',
str(aspect_ratio_range))
aspect_ratio_range = (aspect_ratio_range, aspect_ratio_range)
if not isinstance(area_range, (tuple, list)):
assert isinstance(area_range, numeric_types)
logging.info('Using fixed area range: %s in DetRandomCropAug', area_range)
area_range = (area_range, area_range)
super(DetRandomCropAug, self).__init__(min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range,
min_eject_coverage=min_eject_coverage,
max_attempts=max_attempts)
self.min_object_covered = min_object_covered
self.min_eject_coverage = min_eject_coverage
self.max_attempts = max_attempts
self.aspect_ratio_range = aspect_ratio_range
self.area_range = area_range
self.enabled = False
if (area_range[1] <= 0 or area_range[0] > area_range[1]):
warnings.warn('Skip DetRandomCropAug due to invalid area_range: %s', area_range)
elif (aspect_ratio_range[0] > aspect_ratio_range[1] or aspect_ratio_range[0] <= 0):
warnings.warn('Skip DetRandomCropAug due to invalid aspect_ratio_range: %s',
aspect_ratio_range)
else:
self.enabled = True
def __call__(self, src, label):
"""Augmenter implementation body"""
crop = self._random_crop_proposal(label, src.shape[0], src.shape[1])
if crop:
x, y, w, h, label = crop
src = fixed_crop(src, x, y, w, h, None)
return (src, label)
def _calculate_areas(self, label):
"""Calculate areas for multiple labels"""
heights = np.maximum(0, label[:, 3] - label[:, 1])
widths = np.maximum(0, label[:, 2] - label[:, 0])
return heights * widths
def _intersect(self, label, xmin, ymin, xmax, ymax):
"""Calculate intersect areas, normalized."""
left = np.maximum(label[:, 0], xmin)
right = np.minimum(label[:, 2], xmax)
top = np.maximum(label[:, 1], ymin)
bot = np.minimum(label[:, 3], ymax)
invalid = np.where(np.logical_or(left >= right, top >= bot))[0]
out = label.copy()
out[:, 0] = left
out[:, 1] = top
out[:, 2] = right
out[:, 3] = bot
out[invalid, :] = 0
return out
def _check_satisfy_constraints(self, label, xmin, ymin, xmax, ymax, width, height):
"""Check if constrains are satisfied"""
if (xmax - xmin) * (ymax - ymin) < 2:
return False # only 1 pixel
x1 = float(xmin) / width
y1 = float(ymin) / height
x2 = float(xmax) / width
y2 = float(ymax) / height
object_areas = self._calculate_areas(label[:, 1:])
valid_objects = np.where(object_areas * width * height > 2)[0]
if valid_objects.size < 1:
return False
intersects = self._intersect(label[valid_objects, 1:], x1, y1, x2, y2)
coverages = self._calculate_areas(intersects) / object_areas[valid_objects]
coverages = coverages[np.where(coverages > 0)[0]]
return coverages.size > 0 and np.amin(coverages) > self.min_object_covered
def _update_labels(self, label, crop_box, height, width):
"""Convert labels according to crop box"""
xmin = float(crop_box[0]) / width
ymin = float(crop_box[1]) / height
w = float(crop_box[2]) / width
h = float(crop_box[3]) / height
out = label.copy()
out[:, (1, 3)] -= xmin
out[:, (2, 4)] -= ymin
out[:, (1, 3)] /= w
out[:, (2, 4)] /= h
out[:, 1:5] = np.maximum(0, out[:, 1:5])
out[:, 1:5] = np.minimum(1, out[:, 1:5])
coverage = self._calculate_areas(out[:, 1:]) * w * h / self._calculate_areas(label[:, 1:])
valid = np.logical_and(out[:, 3] > out[:, 1], out[:, 4] > out[:, 2])
valid = np.logical_and(valid, coverage > self.min_eject_coverage)
valid = np.where(valid)[0]
if valid.size < 1:
return None
out = out[valid, :]
return out
def _random_crop_proposal(self, label, height, width):
"""Propose cropping areas"""
from math import sqrt
if not self.enabled or height <= 0 or width <= 0:
return ()
min_area = self.area_range[0] * height * width
max_area = self.area_range[1] * height * width
for _ in range(self.max_attempts):
ratio = random.uniform(*self.aspect_ratio_range)
if ratio <= 0:
continue
h = int(round(sqrt(min_area / ratio)))
max_h = int(round(sqrt(max_area / ratio)))
if round(max_h * ratio) > width:
# find smallest max_h satifying round(max_h * ratio) <= width
max_h = int((width + 0.4999999) / ratio)
if max_h > height:
max_h = height
if h > max_h:
h = max_h
if h < max_h:
# generate random h in range [h, max_h]
h = random.randint(h, max_h)
w = int(round(h * ratio))
assert w <= width
# trying to fix rounding problems
area = w * h
if area < min_area:
h += 1
w = int(round(h * ratio))
area = w * h
if area > max_area:
h -= 1
w = int(round(h * ratio))
area = w * h
if not (min_area <= area <= max_area and 0 <= w <= width and 0 <= h <= height):
continue
y = random.randint(0, max(0, height - h))
x = random.randint(0, max(0, width - w))
if self._check_satisfy_constraints(label, x, y, x + w, y + h, width, height):
new_label = self._update_labels(label, (x, y, w, h), height, width)
if new_label is not None:
return (x, y, w, h, new_label)
return ()
[docs]class DetRandomPadAug(DetAugmenter):
"""Random padding augmenter.
Parameters
----------
aspect_ratio_range : tuple of floats, default=(0.75, 1.33)
The padded area of the image must have an aspect ratio = width / height
within this range.
area_range : tuple of floats, default=(1.0, 3.0)
The padded area of the image must be larger than the original area
max_attempts : int, default=50
Number of attempts at generating a padded region of the image of the
specified constraints. After max_attempts failures, return the original image.
pad_val: float or tuple of float, default=(128, 128, 128)
pixel value to be filled when padding is enabled.
"""
def __init__(self, aspect_ratio_range=(0.75, 1.33), area_range=(1.0, 3.0),
max_attempts=50, pad_val=(128, 128, 128)):
if not isinstance(pad_val, (list, tuple)):
assert isinstance(pad_val, numeric_types)
pad_val = (pad_val)
if not isinstance(aspect_ratio_range, (list, tuple)):
assert isinstance(aspect_ratio_range, numeric_types)
logging.info('Using fixed aspect ratio: %s in DetRandomPadAug',
str(aspect_ratio_range))
aspect_ratio_range = (aspect_ratio_range, aspect_ratio_range)
if not isinstance(area_range, (tuple, list)):
assert isinstance(area_range, numeric_types)
logging.info('Using fixed area range: %s in DetRandomPadAug', area_range)
area_range = (area_range, area_range)
super(DetRandomPadAug, self).__init__(aspect_ratio_range=aspect_ratio_range,
area_range=area_range, max_attempts=max_attempts,
pad_val=pad_val)
self.pad_val = pad_val
self.aspect_ratio_range = aspect_ratio_range
self.area_range = area_range
self.max_attempts = max_attempts
self.enabled = False
if (area_range[1] <= 1.0 or area_range[0] > area_range[1]):
warnings.warn('Skip DetRandomPadAug due to invalid parameters: %s', area_range)
elif (aspect_ratio_range[0] <= 0 or aspect_ratio_range[0] > aspect_ratio_range[1]):
warnings.warn('Skip DetRandomPadAug due to invalid aspect_ratio_range: %s',
aspect_ratio_range)
else:
self.enabled = True
def __call__(self, src, label):
"""Augmenter body"""
height, width, _ = src.shape
pad = self._random_pad_proposal(label, height, width)
if pad:
x, y, w, h, label = pad
src = copyMakeBorder(src, y, h-y-height, x, w-x-width, 16, values=self.pad_val)
return (src, label)
def _update_labels(self, label, pad_box, height, width):
"""Update label according to padding region"""
out = label.copy()
out[:, (1, 3)] = (out[:, (1, 3)] * width + pad_box[0]) / pad_box[2]
out[:, (2, 4)] = (out[:, (2, 4)] * height + pad_box[1]) / pad_box[3]
return out
def _random_pad_proposal(self, label, height, width):
"""Generate random padding region"""
from math import sqrt
if not self.enabled or height <= 0 or width <= 0:
return ()
min_area = self.area_range[0] * height * width
max_area = self.area_range[1] * height * width
for _ in range(self.max_attempts):
ratio = random.uniform(*self.aspect_ratio_range)
if ratio <= 0:
continue
h = int(round(sqrt(min_area / ratio)))
max_h = int(round(sqrt(max_area / ratio)))
if round(h * ratio) < width:
h = int((width + 0.499999) / ratio)
if h < height:
h = height
if h > max_h:
h = max_h
if h < max_h:
h = random.randint(h, max_h)
w = int(round(h * ratio))
if (h - height) < 2 or (w - width) < 2:
continue # marginal padding is not helpful
y = random.randint(0, max(0, h - height))
x = random.randint(0, max(0, w - width))
new_label = self._update_labels(label, (x, y, w, h), height, width)
return (x, y, w, h, new_label)
return ()
def CreateMultiRandCropAugmenter(min_object_covered=0.1, aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0), min_eject_coverage=0.3,
max_attempts=50, skip_prob=0):
"""Helper function to create multiple random crop augmenters.
Parameters
----------
min_object_covered : float or list of float, default=0.1
The cropped area of the image must contain at least this fraction of
any bounding box supplied. The value of this parameter should be non-negative.
In the case of 0, the cropped area does not need to overlap any of the
bounding boxes supplied.
min_eject_coverage : float or list of float, default=0.3
The minimum coverage of cropped sample w.r.t its original size. With this
constraint, objects that have marginal area after crop will be discarded.
aspect_ratio_range : tuple of floats or list of tuple of floats, default=(0.75, 1.33)
The cropped area of the image must have an aspect ratio = width / height
within this range.
area_range : tuple of floats or list of tuple of floats, default=(0.05, 1.0)
The cropped area of the image must contain a fraction of the supplied
image within in this range.
max_attempts : int or list of int, default=50
Number of attempts at generating a cropped/padded region of the image of the
specified constraints. After max_attempts failures, return the original image.
Examples
--------
>>> # An example of creating multiple random crop augmenters
>>> min_object_covered = [0.1, 0.3, 0.5, 0.7, 0.9] # use 5 augmenters
>>> aspect_ratio_range = (0.75, 1.33) # use same range for all augmenters
>>> area_range = [(0.1, 1.0), (0.2, 1.0), (0.2, 1.0), (0.3, 0.9), (0.5, 1.0)]
>>> min_eject_coverage = 0.3
>>> max_attempts = 50
>>> aug = mx.image.det.CreateMultiRandCropAugmenter(min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range, area_range=area_range,
min_eject_coverage=min_eject_coverage, max_attempts=max_attempts,
skip_prob=0)
>>> aug.dumps() # show some details
"""
def align_parameters(params):
"""Align parameters as pairs"""
out_params = []
num = 1
for p in params:
if not isinstance(p, list):
p = [p]
out_params.append(p)
num = max(num, len(p))
# align for each param
for k, p in enumerate(out_params):
if len(p) != num:
assert len(p) == 1
out_params[k] = p * num
return out_params
aligned_params = align_parameters([min_object_covered, aspect_ratio_range, area_range,
min_eject_coverage, max_attempts])
augs = []
for moc, arr, ar, mec, ma in zip(*aligned_params):
augs.append(DetRandomCropAug(min_object_covered=moc, aspect_ratio_range=arr,
area_range=ar, min_eject_coverage=mec, max_attempts=ma))
return DetRandomSelectAug(augs, skip_prob=skip_prob)
def CreateDetAugmenter(data_shape, resize=0, rand_crop=0, rand_pad=0, rand_gray=0,
rand_mirror=False, mean=None, std=None, brightness=0, contrast=0,
saturation=0, pca_noise=0, hue=0, inter_method=2, min_object_covered=0.1,
aspect_ratio_range=(0.75, 1.33), area_range=(0.05, 3.0),
min_eject_coverage=0.3, max_attempts=50, pad_val=(127, 127, 127)):
"""Create augmenters for detection.
Parameters
----------
data_shape : tuple of int
Shape for output data
resize : int
Resize shorter edge if larger than 0 at the begining
rand_crop : float
[0, 1], probability to apply random cropping
rand_pad : float
[0, 1], probability to apply random padding
rand_gray : float
[0, 1], probability to convert to grayscale for all channels
rand_mirror : bool
Whether to apply horizontal flip to image with probability 0.5
mean : np.ndarray or None
Mean pixel values for [r, g, b]
std : np.ndarray or None
Standard deviations for [r, g, b]
brightness : float
Brightness jittering range (percent)
contrast : float
Contrast jittering range (percent)
saturation : float
Saturation jittering range (percent)
hue : float
Hue jittering range (percent)
pca_noise : float
Pca noise level (percent)
inter_method : int, default=2(Area-based)
Interpolation method for all resizing operations
Possible values:
0: Nearest Neighbors Interpolation.
1: Bilinear interpolation.
2: Area-based (resampling using pixel area relation). It may be a
preferred method for image decimation, as it gives moire-free
results. But when the image is zoomed, it is similar to the Nearest
Neighbors method. (used by default).
3: Bicubic interpolation over 4x4 pixel neighborhood.
4: Lanczos interpolation over 8x8 pixel neighborhood.
9: Cubic for enlarge, area for shrink, bilinear for others
10: Random select from interpolation method metioned above.
Note:
When shrinking an image, it will generally look best with AREA-based
interpolation, whereas, when enlarging an image, it will generally look best
with Bicubic (slow) or Bilinear (faster but still looks OK).
min_object_covered : float
The cropped area of the image must contain at least this fraction of
any bounding box supplied. The value of this parameter should be non-negative.
In the case of 0, the cropped area does not need to overlap any of the
bounding boxes supplied.
min_eject_coverage : float
The minimum coverage of cropped sample w.r.t its original size. With this
constraint, objects that have marginal area after crop will be discarded.
aspect_ratio_range : tuple of floats
The cropped area of the image must have an aspect ratio = width / height
within this range.
area_range : tuple of floats
The cropped area of the image must contain a fraction of the supplied
image within in this range.
max_attempts : int
Number of attempts at generating a cropped/padded region of the image of the
specified constraints. After max_attempts failures, return the original image.
pad_val: float
Pixel value to be filled when padding is enabled. pad_val will automatically
be subtracted by mean and divided by std if applicable.
Examples
--------
>>> # An example of creating multiple augmenters
>>> augs = mx.image.CreateDetAugmenter(data_shape=(3, 300, 300), rand_crop=0.5,
... rand_pad=0.5, rand_mirror=True, mean=True, brightness=0.125, contrast=0.125,
... saturation=0.125, pca_noise=0.05, inter_method=10, min_object_covered=[0.3, 0.5, 0.9],
... area_range=(0.3, 3.0))
>>> # dump the details
>>> for aug in augs:
... aug.dumps()
"""
auglist = []
if resize > 0:
auglist.append(DetBorrowAug(ResizeAug(resize, inter_method)))
if rand_crop > 0:
crop_augs = CreateMultiRandCropAugmenter(min_object_covered, aspect_ratio_range,
area_range, min_eject_coverage,
max_attempts, skip_prob=(1 - rand_crop))
auglist.append(crop_augs)
if rand_mirror > 0:
auglist.append(DetHorizontalFlipAug(0.5))
# apply random padding as late as possible to save computation
if rand_pad > 0:
pad_aug = DetRandomPadAug(aspect_ratio_range,
(1.0, area_range[1]), max_attempts, pad_val)
auglist.append(DetRandomSelectAug([pad_aug], 1 - rand_pad))
# force resize
auglist.append(DetBorrowAug(ForceResizeAug((data_shape[2], data_shape[1]), inter_method)))
auglist.append(DetBorrowAug(CastAug()))
if brightness or contrast or saturation:
auglist.append(DetBorrowAug(ColorJitterAug(brightness, contrast, saturation)))
if hue:
auglist.append(DetBorrowAug(HueJitterAug(hue)))
if pca_noise > 0:
eigval = np.array([55.46, 4.794, 1.148])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203]])
auglist.append(DetBorrowAug(LightingAug(pca_noise, eigval, eigvec)))
if rand_gray > 0:
auglist.append(DetBorrowAug(RandomGrayAug(rand_gray)))
if mean is True:
mean = np.array([123.68, 116.28, 103.53])
elif mean is not None:
assert isinstance(mean, np.ndarray) and mean.shape[0] in [1, 3]
if std is True:
std = np.array([58.395, 57.12, 57.375])
elif std is not None:
assert isinstance(std, np.ndarray) and std.shape[0] in [1, 3]
if mean is not None or std is not None:
auglist.append(DetBorrowAug(ColorNormalizeAug(mean, std)))
return auglist
[docs]class ImageDetIter(ImageIter):
"""Image iterator with a large number of augmentation choices for detection.
Parameters
----------
aug_list : list or None
Augmenter list for generating distorted images
batch_size : int
Number of examples per batch.
data_shape : tuple
Data shape in (channels, height, width) format.
For now, only RGB image with 3 channels is supported.
path_imgrec : str
Path to image record file (.rec).
Created with tools/im2rec.py or bin/im2rec.
path_imglist : str
Path to image list (.lst).
Created with tools/im2rec.py or with custom script.
Format: Tab separated record of index, one or more labels and relative_path_from_root.
imglist: list
A list of images with the label(s).
Each item is a list [imagelabel: float or list of float, imgpath].
path_root : str
Root folder of image files.
path_imgidx : str
Path to image index file. Needed for partition and shuffling when using .rec source.
shuffle : bool
Whether to shuffle all images at the start of each iteration or not.
Can be slow for HDD.
part_index : int
Partition index.
num_parts : int
Total number of partitions.
data_name : str
Data name for provided symbols.
label_name : str
Name for detection labels
last_batch_handle : str, optional
How to handle the last batch.
This parameter can be 'pad'(default), 'discard' or 'roll_over'.
If 'pad', the last batch will be padded with data starting from the begining
If 'discard', the last batch will be discarded
If 'roll_over', the remaining elements will be rolled over to the next iteration
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
"""
def __init__(self, batch_size, data_shape,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
data_name='data', label_name='label', last_batch_handle='pad', **kwargs):
super(ImageDetIter, self).__init__(batch_size=batch_size, data_shape=data_shape,
path_imgrec=path_imgrec, path_imglist=path_imglist,
path_root=path_root, path_imgidx=path_imgidx,
shuffle=shuffle, part_index=part_index,
num_parts=num_parts, aug_list=[], imglist=imglist,
data_name=data_name, label_name=label_name,
last_batch_handle=last_batch_handle)
if aug_list is None:
self.auglist = CreateDetAugmenter(data_shape, **kwargs)
else:
self.auglist = aug_list
# went through all labels to get the proper label shape
label_shape = self._estimate_label_shape()
self.provide_label = [(label_name, (self.batch_size, label_shape[0], label_shape[1]))]
self.label_shape = label_shape
def _check_valid_label(self, label):
"""Validate label and its shape."""
if len(label.shape) != 2 or label.shape[1] < 5:
msg = "Label with shape (1+, 5+) required, %s received." % str(label)
raise RuntimeError(msg)
valid_label = np.where(np.logical_and(label[:, 0] >= 0, label[:, 3] > label[:, 1],
label[:, 4] > label[:, 2]))[0]
if valid_label.size < 1:
raise RuntimeError('Invalid label occurs.')
def _estimate_label_shape(self):
"""Helper function to estimate label shape"""
max_count = 0
self.reset()
try:
while True:
label, _ = self.next_sample()
label = self._parse_label(label)
max_count = max(max_count, label.shape[0])
except StopIteration:
pass
self.reset()
return (max_count, label.shape[1])
def _parse_label(self, label):
"""Helper function to parse object detection label.
Format for raw label:
n \t k \t ... \t [id \t xmin\t ymin \t xmax \t ymax \t ...] \t [repeat]
where n is the width of header, 2 or larger
k is the width of each object annotation, can be arbitrary, at least 5
"""
if isinstance(label, nd.NDArray):
label = label.asnumpy()
raw = label.ravel()
if raw.size < 7:
raise RuntimeError("Label shape is invalid: " + str(raw.shape))
header_width = int(raw[0])
obj_width = int(raw[1])
if (raw.size - header_width) % obj_width != 0:
msg = "Label shape %s inconsistent with annotation width %d." \
%(str(raw.shape), obj_width)
raise RuntimeError(msg)
out = np.reshape(raw[header_width:], (-1, obj_width))
# remove bad ground-truths
valid = np.where(np.logical_and(out[:, 3] > out[:, 1], out[:, 4] > out[:, 2]))[0]
if valid.size < 1:
raise RuntimeError('Encounter sample with no valid label.')
return out[valid, :]
[docs] def reshape(self, data_shape=None, label_shape=None):
"""Reshape iterator for data_shape or label_shape.
Parameters
----------
data_shape : tuple or None
Reshape the data_shape to the new shape if not None
label_shape : tuple or None
Reshape label shape to new shape if not None
"""
if data_shape is not None:
self.check_data_shape(data_shape)
self.provide_data = [(self.provide_data[0][0], (self.batch_size,) + data_shape)]
self.data_shape = data_shape
if label_shape is not None:
self.check_label_shape(label_shape)
self.provide_label = [(self.provide_label[0][0], (self.batch_size,) + label_shape)]
self.label_shape = label_shape
def _batchify(self, batch_data, batch_label, start=0):
"""Override the helper function for batchifying data"""
i = start
batch_size = self.batch_size
try:
while i < batch_size:
label, s = self.next_sample()
data = self.imdecode(s)
try:
self.check_valid_image([data])
label = self._parse_label(label)
data, label = self.augmentation_transform(data, label)
self._check_valid_label(label)
except RuntimeError as e:
logging.debug('Invalid image, skipping: %s', str(e))
continue
for datum in [data]:
assert i < batch_size, 'Batch size must be multiples of augmenter output length'
batch_data[i] = self.postprocess_data(datum)
num_object = label.shape[0]
batch_label[i][0:num_object] = nd.array(label)
if num_object < batch_label[i].shape[0]:
batch_label[i][num_object:] = -1
i += 1
except StopIteration:
if not i:
raise StopIteration
return i
[docs] def next(self):
"""Override the function for returning next batch."""
batch_size = self.batch_size
c, h, w = self.data_shape
# if last batch data is rolled over
if self._cache_data is not None:
# check both the data and label have values
assert self._cache_label is not None, "_cache_label didn't have values"
assert self._cache_idx is not None, "_cache_idx didn't have values"
batch_data = self._cache_data
batch_label = self._cache_label
i = self._cache_idx
else:
batch_data = nd.zeros((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
batch_label[:] = -1
i = self._batchify(batch_data, batch_label)
# calculate the padding
pad = batch_size - i
# handle padding for the last batch
if pad != 0:
if self.last_batch_handle == 'discard':
raise StopIteration
# if the option is 'roll_over', throw StopIteration and cache the data
elif self.last_batch_handle == 'roll_over' and \
self._cache_data is None:
self._cache_data = batch_data
self._cache_label = batch_label
self._cache_idx = i
raise StopIteration
else:
_ = self._batchify(batch_data, batch_label, i)
if self.last_batch_handle == 'pad':
self._allow_read = False
else:
self._cache_data = None
self._cache_label = None
self._cache_idx = None
return io.DataBatch([batch_data], [batch_label], pad=pad)
[docs] def check_label_shape(self, label_shape):
"""Checks if the new label shape is valid"""
if not len(label_shape) == 2:
raise ValueError('label_shape should have length 2')
if label_shape[0] < self.label_shape[0]:
msg = 'Attempts to reduce label count from %d to %d, not allowed.' \
% (self.label_shape[0], label_shape[0])
raise ValueError(msg)
if label_shape[1] != self.provide_label[0][1][2]:
msg = 'label_shape object width inconsistent: %d vs %d.' \
% (self.provide_label[0][1][2], label_shape[1])
raise ValueError(msg)
[docs] def draw_next(self, color=None, thickness=2, mean=None, std=None, clip=True,
waitKey=None, window_name='draw_next', id2labels=None):
"""Display next image with bounding boxes drawn.
Parameters
----------
color : tuple
Bounding box color in RGB, use None for random color
thickness : int
Bounding box border thickness
mean : True or numpy.ndarray
Compensate for the mean to have better visual effect
std : True or numpy.ndarray
Revert standard deviations
clip : bool
If true, clip to [0, 255] for better visual effect
waitKey : None or int
Hold the window for waitKey milliseconds if set, skip ploting if None
window_name : str
Plot window name if waitKey is set.
id2labels : dict
Mapping of labels id to labels name.
Returns
-------
numpy.ndarray
Examples
--------
>>> # use draw_next to get images with bounding boxes drawn
>>> iterator = mx.image.ImageDetIter(1, (3, 600, 600), path_imgrec='train.rec')
>>> for image in iterator.draw_next(waitKey=None):
... # display image
>>> # or let draw_next display using cv2 module
>>> for image in iterator.draw_next(waitKey=0, window_name='disp'):
... pass
"""
try:
import cv2
except ImportError as e:
warnings.warn('Unable to import cv2, skip drawing: %s', str(e))
return
count = 0
try:
while True:
label, s = self.next_sample()
data = self.imdecode(s)
try:
self.check_valid_image([data])
label = self._parse_label(label)
except RuntimeError as e:
logging.debug('Invalid image, skipping: %s', str(e))
continue
count += 1
data, label = self.augmentation_transform(data, label)
image = data.asnumpy()
# revert color_normalize
if std is True:
std = np.array([58.395, 57.12, 57.375])
elif std is not None:
assert isinstance(std, np.ndarray) and std.shape[0] in [1, 3]
if std is not None:
image *= std
if mean is True:
mean = np.array([123.68, 116.28, 103.53])
elif mean is not None:
assert isinstance(mean, np.ndarray) and mean.shape[0] in [1, 3]
if mean is not None:
image += mean
# swap RGB
image[:, :, (0, 1, 2)] = image[:, :, (2, 1, 0)]
if clip:
image = np.maximum(0, np.minimum(255, image))
if color:
color = color[::-1]
image = image.astype(np.uint8)
height, width, _ = image.shape
for i in range(label.shape[0]):
x1 = int(label[i, 1] * width)
if x1 < 0:
continue
y1 = int(label[i, 2] * height)
x2 = int(label[i, 3] * width)
y2 = int(label[i, 4] * height)
bc = np.random.rand(3) * 255 if not color else color
cv2.rectangle(image, (x1, y1), (x2, y2), bc, thickness)
if id2labels is not None:
cls_id = int(label[i, 0])
if cls_id in id2labels:
cls_name = id2labels[cls_id]
text = "{:s}".format(cls_name)
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
text_height = cv2.getTextSize(text, font, font_scale, 2)[0][1]
tc = (255, 255, 255)
tpos = (x1 + 5, y1 + text_height + 5)
cv2.putText(image, text, tpos, font, font_scale, tc, 2)
if waitKey is not None:
cv2.imshow(window_name, image)
cv2.waitKey(waitKey)
yield image
except StopIteration:
if not count:
return
[docs] def sync_label_shape(self, it, verbose=False):
"""Synchronize label shape with the input iterator. This is useful when
train/validation iterators have different label padding.
Parameters
----------
it : ImageDetIter
The other iterator to synchronize
verbose : bool
Print verbose log if true
Returns
-------
ImageDetIter
The synchronized other iterator, the internal label shape is updated as well.
Examples
--------
>>> train_iter = mx.image.ImageDetIter(32, (3, 300, 300), path_imgrec='train.rec')
>>> val_iter = mx.image.ImageDetIter(32, (3, 300, 300), path.imgrec='val.rec')
>>> train_iter.label_shape
(30, 6)
>>> val_iter.label_shape
(25, 6)
>>> val_iter = train_iter.sync_label_shape(val_iter, verbose=False)
>>> train_iter.label_shape
(30, 6)
>>> val_iter.label_shape
(30, 6)
"""
assert isinstance(it, ImageDetIter), 'Synchronize with invalid iterator.'
train_label_shape = self.label_shape
val_label_shape = it.label_shape
assert train_label_shape[1] == val_label_shape[1], "object width mismatch."
max_count = max(train_label_shape[0], val_label_shape[0])
if max_count > train_label_shape[0]:
self.reshape(None, (max_count, train_label_shape[1]))
if max_count > val_label_shape[0]:
it.reshape(None, (max_count, val_label_shape[1]))
if verbose and max_count > min(train_label_shape[0], val_label_shape[0]):
logging.info('Resized label_shape to (%d, %d).', max_count, train_label_shape[1])
return it