# Source code for mxnet.image.detection

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=unused-import
"""Read images and perform augmentations for object detection."""

from __future__ import absolute_import, print_function

import json
import logging
import random
import warnings

import numpy as np

from ..base import numeric_types
from .. import ndarray as nd
from ..ndarray._internal import _cvcopyMakeBorder as copyMakeBorder
from .. import io
from .image import RandomOrderAug, ColorJitterAug, LightingAug, ColorNormalizeAug
from .image import ResizeAug, ForceResizeAug, CastAug, HueJitterAug, RandomGrayAug
from .image import fixed_crop, ImageIter, Augmenter


class DetAugmenter(object):
    """Base class shared by all detection augmenters.

    A detection augmenter transforms an ``(image, label)`` pair together.
    Keyword arguments given to the constructor are remembered (with NDArray
    and numpy values converted to plain lists) so that :meth:`dumps` can
    serialize the augmenter to JSON.
    """
    def __init__(self, **kwargs):
        self._kwargs = kwargs
        # Convert array-like values into JSON-friendly Python lists:
        # NDArray -> numpy.ndarray -> list.
        for key in list(self._kwargs):
            item = self._kwargs[key]
            if isinstance(item, nd.NDArray):
                item = item.asnumpy()
            if isinstance(item, np.ndarray):
                item = item.tolist()
            self._kwargs[key] = item

    def dumps(self):
        """Saves the Augmenter to string

        Returns
        -------
        str
            JSON-encoded ``[lower-cased class name, constructor kwargs]``.
        """
        payload = [self.__class__.__name__.lower(), self._kwargs]
        return json.dumps(payload)

    def __call__(self, src, label):
        """Transform ``(src, label)``; concrete subclasses must implement this."""
        raise NotImplementedError("Must override implementation.")
class DetBorrowAug(DetAugmenter):
    """Adapt a classification ``mx.image.Augmenter`` to the detection protocol.

    The wrapped augmenter is applied to the image only and the label is
    returned unchanged, so only borrow augmenters whose effect does not
    invalidate the bounding boxes.

    Parameters
    ----------
    augmenter : mx.image.Augmenter
        The borrowed standard augmenter which has no effect on label
    """
    def __init__(self, augmenter):
        if isinstance(augmenter, Augmenter):
            super(DetBorrowAug, self).__init__(augmenter=augmenter.dumps())
            self.augmenter = augmenter
        else:
            raise TypeError('Borrowing from invalid Augmenter')

    def dumps(self):
        """Return ``[name, borrowed augmenter dump]`` instead of the default
        JSON string, so the borrowed dump is not duplicated."""
        name = self.__class__.__name__.lower()
        return [name, self.augmenter.dumps()]

    def __call__(self, src, label):
        """Run the borrowed augmenter on ``src``; ``label`` passes through."""
        augmented = self.augmenter(src)
        return augmented, label
class DetRandomSelectAug(DetAugmenter):
    """Randomly select one augmenter to apply, with chance to skip all.

    Parameters
    ----------
    aug_list : DetAugmenter or list/tuple of DetAugmenter
        Candidate augmenters; on each call one of them is chosen uniformly at
        random and applied. A single augmenter is wrapped into a list.
    skip_prob : float
        The probability to skip all augmenters and return input directly.
        Forced to 1 when ``aug_list`` is empty.

    Raises
    ------
    ValueError
        If any element of ``aug_list`` is not a ``DetAugmenter``.
    """
    def __init__(self, aug_list, skip_prob=0):
        super(DetRandomSelectAug, self).__init__(skip_prob=skip_prob)
        if not isinstance(aug_list, (list, tuple)):
            aug_list = [aug_list]
        for aug in aug_list:
            if not isinstance(aug, DetAugmenter):
                raise ValueError('Allow DetAugmenter in list only')
        if not aug_list:
            skip_prob = 1  # disabled
        # Keep a private copy. Shuffling the caller's sequence in place used to
        # mutate shared lists and raised TypeError when a tuple was given.
        self.aug_list = list(aug_list)
        self.skip_prob = skip_prob

    def dumps(self):
        """Override default."""
        return [self.__class__.__name__.lower(), [x.dumps() for x in self.aug_list]]

    def __call__(self, src, label):
        """Apply one randomly chosen augmenter, or return the input unchanged
        with probability ``skip_prob``."""
        if random.random() < self.skip_prob:
            return (src, label)
        # An empty list implies skip_prob == 1, so choice() never sees [] here.
        return random.choice(self.aug_list)(src, label)
class DetHorizontalFlipAug(DetAugmenter):
    """Randomly mirror the image left-to-right and adjust the boxes to match.

    Parameters
    ----------
    p : float
        Probability in [0, 1] that the flip is applied.
    """
    def __init__(self, p):
        super(DetHorizontalFlipAug, self).__init__(p=p)
        self.p = p

    def __call__(self, src, label):
        """Flip ``src`` along axis 1 with probability ``p``.

        Axis 1 is presumably the width axis (HWC image layout); label
        x-coordinates are mirrored in place when the flip happens.
        """
        if random.random() >= self.p:
            return (src, label)
        mirrored = nd.flip(src, axis=1)
        self._flip_label(label)
        return (mirrored, label)

    def _flip_label(self, label):
        """Mirror normalized x-coordinates of ``label`` in place.

        Columns 1 and 3 are treated as xmin and xmax. After mirroring, the new
        xmin is ``1 - old xmax`` and the new xmax is ``1 - old xmin``.
        """
        new_xmin = 1.0 - label[:, 3]
        new_xmax = 1.0 - label[:, 1]
        label[:, 1] = new_xmin
        label[:, 3] = new_xmax
class DetRandomCropAug(DetAugmenter):
    """Random cropping with constraints

    Parameters
    ----------
    min_object_covered : float, default=0.1
        The cropped area of the image must contain at least this fraction of
        any bounding box supplied. The value of this parameter should be non-negative.
        In the case of 0, the cropped area does not need to overlap any of the
        bounding boxes supplied.
    min_eject_coverage : float, default=0.3
        The minimum coverage of cropped sample w.r.t its original size. With this
        constraint, objects that have marginal area after crop will be discarded.
    aspect_ratio_range : tuple of floats, default=(0.75, 1.33)
        The cropped area of the image must have an aspect ratio = width / height
        within this range. A single number means a fixed aspect ratio.
    area_range : tuple of floats, default=(0.05, 1.0)
        The cropped area of the image must contain a fraction of the supplied
        image within this range. A single number means a fixed area fraction.
    max_attempts : int, default=50
        Number of attempts at generating a cropped/padded region of the image of the
        specified constraints. After max_attempts failures, return the original image.

    Notes
    -----
    Invalid ``area_range``/``aspect_ratio_range`` values disable the augmenter
    (it becomes a no-op) and emit a ``UserWarning``.
    """
    def __init__(self, min_object_covered=0.1, aspect_ratio_range=(0.75, 1.33),
                 area_range=(0.05, 1.0), min_eject_coverage=0.3, max_attempts=50):
        if not isinstance(aspect_ratio_range, (tuple, list)):
            assert isinstance(aspect_ratio_range, numeric_types)
            logging.info('Using fixed aspect ratio: %s in DetRandomCropAug',
                         str(aspect_ratio_range))
            aspect_ratio_range = (aspect_ratio_range, aspect_ratio_range)
        if not isinstance(area_range, (tuple, list)):
            assert isinstance(area_range, numeric_types)
            logging.info('Using fixed area range: %s in DetRandomCropAug', area_range)
            area_range = (area_range, area_range)
        super(DetRandomCropAug, self).__init__(min_object_covered=min_object_covered,
                                               aspect_ratio_range=aspect_ratio_range,
                                               area_range=area_range,
                                               min_eject_coverage=min_eject_coverage,
                                               max_attempts=max_attempts)
        self.min_object_covered = min_object_covered
        self.min_eject_coverage = min_eject_coverage
        self.max_attempts = max_attempts
        self.aspect_ratio_range = aspect_ratio_range
        self.area_range = area_range
        self.enabled = False
        # warnings.warn() does not %-format: its second positional argument is
        # the warning *category*, and passing a value there raises TypeError.
        # The message is therefore formatted before the call.
        if (area_range[1] <= 0 or area_range[0] > area_range[1]):
            warnings.warn('Skip DetRandomCropAug due to invalid area_range: %s'
                          % str(area_range))
        elif (aspect_ratio_range[0] > aspect_ratio_range[1] or aspect_ratio_range[0] <= 0):
            warnings.warn('Skip DetRandomCropAug due to invalid aspect_ratio_range: %s'
                          % str(aspect_ratio_range))
        else:
            self.enabled = True

    def __call__(self, src, label):
        """Augmenter implementation body"""
        crop = self._random_crop_proposal(label, src.shape[0], src.shape[1])
        if crop:
            x, y, w, h, label = crop
            src = fixed_crop(src, x, y, w, h, None)
        return (src, label)

    def _calculate_areas(self, label):
        """Calculate areas for multiple labels.

        Columns 0..3 of ``label`` are read as xmin, ymin, xmax, ymax; negative
        extents are clamped to zero.
        """
        heights = np.maximum(0, label[:, 3] - label[:, 1])
        widths = np.maximum(0, label[:, 2] - label[:, 0])
        return heights * widths

    def _intersect(self, label, xmin, ymin, xmax, ymax):
        """Calculate intersect areas, normalized.

        Returns a copy of ``label`` whose first four columns are clipped to the
        given box. Rows with no overlap are zeroed entirely.
        """
        left = np.maximum(label[:, 0], xmin)
        right = np.minimum(label[:, 2], xmax)
        top = np.maximum(label[:, 1], ymin)
        bot = np.minimum(label[:, 3], ymax)
        invalid = np.where(np.logical_or(left >= right, top >= bot))[0]
        out = label.copy()
        out[:, 0] = left
        out[:, 1] = top
        out[:, 2] = right
        out[:, 3] = bot
        out[invalid, :] = 0
        return out

    def _check_satisfy_constraints(self, label, xmin, ymin, xmax, ymax, width, height):
        """Check if constraints are satisfied.

        ``xmin``..``ymax`` are pixel coordinates of the candidate crop. ``label``
        rows are ``[id, xmin, ymin, xmax, ymax, ...]`` in normalized coordinates.
        """
        if (xmax - xmin) * (ymax - ymin) < 2:
            return False  # only 1 pixel
        x1 = float(xmin) / width
        y1 = float(ymin) / height
        x2 = float(xmax) / width
        y2 = float(ymax) / height
        object_areas = self._calculate_areas(label[:, 1:])
        # Ignore objects covering at most ~2 pixels in the source image.
        valid_objects = np.where(object_areas * width * height > 2)[0]
        if valid_objects.size < 1:
            return False
        intersects = self._intersect(label[valid_objects, 1:], x1, y1, x2, y2)
        coverages = self._calculate_areas(intersects) / object_areas[valid_objects]
        coverages = coverages[np.where(coverages > 0)[0]]
        return coverages.size > 0 and np.amin(coverages) > self.min_object_covered

    def _update_labels(self, label, crop_box, height, width):
        """Convert labels according to crop box.

        Returns the re-normalized labels, or None if no object keeps more than
        ``min_eject_coverage`` of its original area.
        """
        xmin = float(crop_box[0]) / width
        ymin = float(crop_box[1]) / height
        w = float(crop_box[2]) / width
        h = float(crop_box[3]) / height
        out = label.copy()
        out[:, (1, 3)] -= xmin
        out[:, (2, 4)] -= ymin
        out[:, (1, 3)] /= w
        out[:, (2, 4)] /= h
        out[:, 1:5] = np.maximum(0, out[:, 1:5])
        out[:, 1:5] = np.minimum(1, out[:, 1:5])
        # Area kept inside the crop, expressed relative to the original box area.
        coverage = self._calculate_areas(out[:, 1:]) * w * h / self._calculate_areas(label[:, 1:])
        valid = np.logical_and(out[:, 3] > out[:, 1], out[:, 4] > out[:, 2])
        valid = np.logical_and(valid, coverage > self.min_eject_coverage)
        valid = np.where(valid)[0]
        if valid.size < 1:
            return None
        out = out[valid, :]
        return out

    def _random_crop_proposal(self, label, height, width):
        """Propose cropping areas.

        Returns ``(x, y, w, h, new_label)`` in pixels for the first attempt
        satisfying all constraints, or ``()`` if disabled or all attempts fail.
        """
        from math import sqrt

        if not self.enabled or height <= 0 or width <= 0:
            return ()
        min_area = self.area_range[0] * height * width
        max_area = self.area_range[1] * height * width
        for _ in range(self.max_attempts):
            ratio = random.uniform(*self.aspect_ratio_range)
            if ratio <= 0:
                continue
            h = int(round(sqrt(min_area / ratio)))
            max_h = int(round(sqrt(max_area / ratio)))
            if round(max_h * ratio) > width:
                # find the largest max_h satisfying round(max_h * ratio) <= width
                max_h = int((width + 0.4999999) / ratio)
            if max_h > height:
                max_h = height
            if h > max_h:
                h = max_h
            if h < max_h:
                # generate random h in range [h, max_h]
                h = random.randint(h, max_h)
            w = int(round(h * ratio))
            assert w <= width

            # trying to fix rounding problems
            area = w * h
            if area < min_area:
                h += 1
                w = int(round(h * ratio))
                area = w * h
            if area > max_area:
                h -= 1
                w = int(round(h * ratio))
                area = w * h
            if not (min_area <= area <= max_area and 0 <= w <= width and 0 <= h <= height):
                continue

            y = random.randint(0, max(0, height - h))
            x = random.randint(0, max(0, width - w))
            if self._check_satisfy_constraints(label, x, y, x + w, y + h, width, height):
                new_label = self._update_labels(label, (x, y, w, h), height, width)
                if new_label is not None:
                    return (x, y, w, h, new_label)
        return ()
class DetRandomPadAug(DetAugmenter):
    """Random padding augmenter.

    Parameters
    ----------
    aspect_ratio_range : tuple of floats, default=(0.75, 1.33)
        The padded area of the image must have an aspect ratio = width / height
        within this range. A single number means a fixed aspect ratio.
    area_range : tuple of floats, default=(1.0, 3.0)
        The padded area of the image must be larger than the original area.
        A single number means a fixed area ratio.
    max_attempts : int, default=50
        Number of attempts at generating a padded region of the image of the
        specified constraints. After max_attempts failures, return the original image.
    pad_val : float or tuple of float, default=(128, 128, 128)
        Pixel value to be filled when padding is enabled. A single number is
        stored as a one-element tuple.

    Notes
    -----
    Invalid ``area_range``/``aspect_ratio_range`` values disable the augmenter
    (it becomes a no-op) and emit a ``UserWarning``.
    """
    def __init__(self, aspect_ratio_range=(0.75, 1.33), area_range=(1.0, 3.0),
                 max_attempts=50, pad_val=(128, 128, 128)):
        if not isinstance(pad_val, (list, tuple)):
            assert isinstance(pad_val, numeric_types)
            # `(pad_val)` is only a parenthesized scalar; the trailing comma
            # builds the intended one-element tuple.
            pad_val = (pad_val,)
        if not isinstance(aspect_ratio_range, (list, tuple)):
            assert isinstance(aspect_ratio_range, numeric_types)
            logging.info('Using fixed aspect ratio: %s in DetRandomPadAug',
                         str(aspect_ratio_range))
            aspect_ratio_range = (aspect_ratio_range, aspect_ratio_range)
        if not isinstance(area_range, (tuple, list)):
            assert isinstance(area_range, numeric_types)
            logging.info('Using fixed area range: %s in DetRandomPadAug', area_range)
            area_range = (area_range, area_range)
        super(DetRandomPadAug, self).__init__(aspect_ratio_range=aspect_ratio_range,
                                              area_range=area_range,
                                              max_attempts=max_attempts,
                                              pad_val=pad_val)
        self.pad_val = pad_val
        self.aspect_ratio_range = aspect_ratio_range
        self.area_range = area_range
        self.max_attempts = max_attempts
        self.enabled = False
        # warnings.warn() does not %-format: its second positional argument is
        # the warning *category*, and passing a value there raises TypeError.
        if (area_range[1] <= 1.0 or area_range[0] > area_range[1]):
            warnings.warn('Skip DetRandomPadAug due to invalid parameters: %s'
                          % str(area_range))
        elif (aspect_ratio_range[0] <= 0 or aspect_ratio_range[0] > aspect_ratio_range[1]):
            warnings.warn('Skip DetRandomPadAug due to invalid aspect_ratio_range: %s'
                          % str(aspect_ratio_range))
        else:
            self.enabled = True

    def __call__(self, src, label):
        """Augmenter body"""
        height, width, _ = src.shape
        pad = self._random_pad_proposal(label, height, width)
        if pad:
            x, y, w, h, label = pad
            # Border type 16 is presumably OpenCV's BORDER_ISOLATED flag
            # (combined with BORDER_CONSTANT == 0) -- TODO confirm.
            src = copyMakeBorder(src, y, h - y - height, x, w - x - width, 16,
                                 values=self.pad_val)
        return (src, label)

    def _update_labels(self, label, pad_box, height, width):
        """Update label according to padding region.

        ``pad_box`` is ``(x, y, w, h)``: the original image's offset inside the
        padded canvas and the canvas size, in pixels.
        """
        out = label.copy()
        out[:, (1, 3)] = (out[:, (1, 3)] * width + pad_box[0]) / pad_box[2]
        out[:, (2, 4)] = (out[:, (2, 4)] * height + pad_box[1]) / pad_box[3]
        return out

    def _random_pad_proposal(self, label, height, width):
        """Generate random padding region.

        Returns ``(x, y, w, h, new_label)`` in pixels, or ``()`` if disabled or
        no attempt produced padding of at least 2 pixels in each direction.
        """
        from math import sqrt
        if not self.enabled or height <= 0 or width <= 0:
            return ()
        min_area = self.area_range[0] * height * width
        max_area = self.area_range[1] * height * width
        for _ in range(self.max_attempts):
            ratio = random.uniform(*self.aspect_ratio_range)
            if ratio <= 0:
                continue
            h = int(round(sqrt(min_area / ratio)))
            max_h = int(round(sqrt(max_area / ratio)))
            if round(h * ratio) < width:
                h = int((width + 0.499999) / ratio)
            if h < height:
                h = height
            if h > max_h:
                h = max_h
            if h < max_h:
                h = random.randint(h, max_h)
            w = int(round(h * ratio))
            if (h - height) < 2 or (w - width) < 2:
                continue  # marginal padding is not helpful

            y = random.randint(0, max(0, h - height))
            x = random.randint(0, max(0, w - width))
            new_label = self._update_labels(label, (x, y, w, h), height, width)
            return (x, y, w, h, new_label)
        return ()
def CreateMultiRandCropAugmenter(min_object_covered=0.1, aspect_ratio_range=(0.75, 1.33),
                                 area_range=(0.05, 1.0), min_eject_coverage=0.3,
                                 max_attempts=50, skip_prob=0):
    """Helper function to create multiple random crop augmenters.

    Each parameter may be a single value or a list. Lists must all have the
    same length N, or length 1. Single values are broadcast to N. One
    ``DetRandomCropAug`` is built per aligned parameter set, and the result is
    wrapped in a ``DetRandomSelectAug`` that picks one per sample.

    Parameters
    ----------
    min_object_covered : float or list of float, default=0.1
        The cropped area of the image must contain at least this fraction of
        any bounding box supplied. The value of this parameter should be non-negative.
        In the case of 0, the cropped area does not need to overlap any of the
        bounding boxes supplied.
    min_eject_coverage : float or list of float, default=0.3
        The minimum coverage of cropped sample w.r.t its original size. With this
        constraint, objects that have marginal area after crop will be discarded.
    aspect_ratio_range : tuple of floats or list of tuple of floats, default=(0.75, 1.33)
        The cropped area of the image must have an aspect ratio = width / height
        within this range.
    area_range : tuple of floats or list of tuple of floats, default=(0.05, 1.0)
        The cropped area of the image must contain a fraction of the supplied
        image within this range.
    max_attempts : int or list of int, default=50
        Number of attempts at generating a cropped/padded region of the image of the
        specified constraints. After max_attempts failures, return the original image.
    skip_prob : float, default=0
        Probability of skipping cropping entirely.

    Examples
    --------
    >>> # An example of creating multiple random crop augmenters
    >>> min_object_covered = [0.1, 0.3, 0.5, 0.7, 0.9]  # use 5 augmenters
    >>> aspect_ratio_range = (0.75, 1.33)  # use same range for all augmenters
    >>> area_range = [(0.1, 1.0), (0.2, 1.0), (0.2, 1.0), (0.3, 0.9), (0.5, 1.0)]
    >>> min_eject_coverage = 0.3
    >>> max_attempts = 50
    >>> aug = mx.image.det.CreateMultiRandCropAugmenter(min_object_covered=min_object_covered,
            aspect_ratio_range=aspect_ratio_range, area_range=area_range,
            min_eject_coverage=min_eject_coverage, max_attempts=max_attempts,
            skip_prob=0)
    >>> aug.dumps()  # show some details

    """
    def align_parameters(params):
        """Broadcast every parameter to a list of one common length.

        Non-list values become one-element lists. One-element lists are
        repeated to the longest length. Any other length mismatch trips the
        assertion.
        """
        wrapped = [p if isinstance(p, list) else [p] for p in params]
        num = max([1] + [len(p) for p in wrapped])
        aligned = []
        for p in wrapped:
            if len(p) != num:
                assert len(p) == 1
                p = p * num
            aligned.append(p)
        return aligned

    aligned_params = align_parameters([min_object_covered, aspect_ratio_range, area_range,
                                       min_eject_coverage, max_attempts])
    augs = [DetRandomCropAug(min_object_covered=moc, aspect_ratio_range=arr,
                             area_range=ar, min_eject_coverage=mec, max_attempts=ma)
            for moc, arr, ar, mec, ma in zip(*aligned_params)]
    return DetRandomSelectAug(augs, skip_prob=skip_prob)


def CreateDetAugmenter(data_shape, resize=0, rand_crop=0, rand_pad=0, rand_gray=0,
                       rand_mirror=False, mean=None, std=None, brightness=0, contrast=0,
                       saturation=0, pca_noise=0, hue=0, inter_method=2, min_object_covered=0.1,
                       aspect_ratio_range=(0.75, 1.33), area_range=(0.05, 3.0),
                       min_eject_coverage=0.3, max_attempts=50, pad_val=(127, 127, 127)):
    """Create augmenters for detection.

    The returned list is applied in order:
    1. Optional resize.
    2. Random crop.
    3. Horizontal flip.
    4. Random pad.
    5. Forced resize to ``data_shape``.
    6. Cast.
    7. Color jitter, hue, PCA lighting and grayscale.
    8. Optional color normalization.

    Parameters
    ----------
    data_shape : tuple of int
        Shape for output data, ``(channels, height, width)``.
    resize : int
        Resize shorter edge if larger than 0 at the beginning
    rand_crop : float
        [0, 1], probability to apply random cropping
    rand_pad : float
        [0, 1], probability to apply random padding
    rand_gray : float
        [0, 1], probability to convert to grayscale for all channels
    rand_mirror : bool
        Whether to apply horizontal flip to image with probability 0.5
    mean : np.ndarray or None
        Mean pixel values for [r, g, b]; ``True`` selects built-in defaults.
    std : np.ndarray or None
        Standard deviations for [r, g, b]; ``True`` selects built-in defaults.
    brightness : float
        Brightness jittering range (percent)
    contrast : float
        Contrast jittering range (percent)
    saturation : float
        Saturation jittering range (percent)
    hue : float
        Hue jittering range (percent)
    pca_noise : float
        Pca noise level (percent)
    inter_method : int, default=2(Area-based)
        Interpolation method for all resizing operations.

        Possible values:
        0: Nearest Neighbors Interpolation.
        1: Bilinear interpolation.
        2: Area-based (resampling using pixel area relation). It may be a
        preferred method for image decimation, as it gives moire-free
        results. But when the image is zoomed, it is similar to the Nearest
        Neighbors method. (used by default).
        3: Bicubic interpolation over 4x4 pixel neighborhood.
        4: Lanczos interpolation over 8x8 pixel neighborhood.
        9: Cubic for enlarge, area for shrink, bilinear for others
        10: Random select from interpolation method mentioned above.

        Note:
        When shrinking an image, it will generally look best with AREA-based
        interpolation, whereas, when enlarging an image, it will generally look best
        with Bicubic (slow) or Bilinear (faster but still looks OK).
    min_object_covered : float
        The cropped area of the image must contain at least this fraction of
        any bounding box supplied. The value of this parameter should be non-negative.
        In the case of 0, the cropped area does not need to overlap any of the
        bounding boxes supplied.
    min_eject_coverage : float
        The minimum coverage of cropped sample w.r.t its original size. With this
        constraint, objects that have marginal area after crop will be discarded.
    aspect_ratio_range : tuple of floats
        The cropped area of the image must have an aspect ratio = width / height
        within this range.
    area_range : tuple of floats
        The cropped area of the image must contain a fraction of the supplied
        image within this range. ``area_range[1]`` is also used as the upper
        bound of the padding area ratio.
    max_attempts : int
        Number of attempts at generating a cropped/padded region of the image of the
        specified constraints. After max_attempts failures, return the original image.
    pad_val: float
        Pixel value to be filled when padding is enabled. Padding happens before
        normalization, so pad_val is also subtracted by mean and divided by std
        if applicable.

    Examples
    --------
    >>> # An example of creating multiple augmenters
    >>> augs = mx.image.CreateDetAugmenter(data_shape=(3, 300, 300), rand_crop=0.5,
    ...    rand_pad=0.5, rand_mirror=True, mean=True, brightness=0.125, contrast=0.125,
    ...    saturation=0.125, pca_noise=0.05, inter_method=10, min_object_covered=[0.3, 0.5, 0.9],
    ...    area_range=(0.3, 3.0))
    >>> # dump the details
    >>> for aug in augs:
    ...    aug.dumps()
    """
    def borrow(aug):
        """Wrap a classification augmenter into the (src, label) protocol."""
        return DetBorrowAug(aug)

    def resolve_stat(value, default):
        """Expand ``True`` into ``default``; otherwise validate a given array."""
        if value is True:
            return np.array(default)
        if value is not None:
            assert isinstance(value, np.ndarray) and value.shape[0] in [1, 3]
        return value

    auglist = []

    if resize > 0:
        auglist.append(borrow(ResizeAug(resize, inter_method)))

    if rand_crop > 0:
        auglist.append(CreateMultiRandCropAugmenter(
            min_object_covered, aspect_ratio_range, area_range, min_eject_coverage,
            max_attempts, skip_prob=(1 - rand_crop)))

    if rand_mirror > 0:
        auglist.append(DetHorizontalFlipAug(0.5))

    # Padding enlarges the image, so it runs as late as possible to keep the
    # earlier steps cheap.
    if rand_pad > 0:
        pad_aug = DetRandomPadAug(aspect_ratio_range, (1.0, area_range[1]),
                                  max_attempts, pad_val)
        auglist.append(DetRandomSelectAug([pad_aug], 1 - rand_pad))

    # Force resize to (width, height) taken from data_shape = (c, h, w).
    auglist.append(borrow(ForceResizeAug((data_shape[2], data_shape[1]), inter_method)))
    auglist.append(borrow(CastAug()))

    if brightness or contrast or saturation:
        auglist.append(borrow(ColorJitterAug(brightness, contrast, saturation)))
    if hue:
        auglist.append(borrow(HueJitterAug(hue)))
    if pca_noise > 0:
        eigval = np.array([55.46, 4.794, 1.148])
        eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                           [-0.5808, -0.0045, -0.8140],
                           [-0.5836, -0.6948, 0.4203]])
        auglist.append(borrow(LightingAug(pca_noise, eigval, eigvec)))
    if rand_gray > 0:
        auglist.append(borrow(RandomGrayAug(rand_gray)))

    mean = resolve_stat(mean, [123.68, 116.28, 103.53])
    std = resolve_stat(std, [58.395, 57.12, 57.375])
    if mean is not None or std is not None:
        auglist.append(borrow(ColorNormalizeAug(mean, std)))

    return auglist
class ImageDetIter(ImageIter):
    """Image iterator with a large number of augmentation choices for detection.

    Parameters
    ----------
    aug_list : list or None
        Augmenter list for generating distorted images
    batch_size : int
        Number of examples per batch.
    data_shape : tuple
        Data shape in (channels, height, width) format.
        For now, only RGB image with 3 channels is supported.
    path_imgrec : str
        Path to image record file (.rec).
        Created with tools/im2rec.py or bin/im2rec.
    path_imglist : str
        Path to image list (.lst).
        Created with tools/im2rec.py or with custom script.
        Format: Tab separated record of index, one or more labels and relative_path_from_root.
    imglist: list
        A list of images with the label(s).
        Each item is a list [imagelabel: float or list of float, imgpath].
    path_root : str
        Root folder of image files.
    path_imgidx : str
        Path to image index file. Needed for partition and shuffling when using .rec source.
    shuffle : bool
        Whether to shuffle all images at the start of each iteration or not.
        Can be slow for HDD.
    part_index : int
        Partition index.
    num_parts : int
        Total number of partitions.
    data_name : str
        Data name for provided symbols.
    label_name : str
        Name for detection labels
    last_batch_handle : str, optional
        How to handle the last batch.
        This parameter can be 'pad'(default), 'discard' or 'roll_over'.
        If 'pad', the last batch will be padded with data starting from the beginning
        If 'discard', the last batch will be discarded
        If 'roll_over', the remaining elements will be rolled over to the next iteration
    kwargs : ...
        More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
    """

    def __init__(self, batch_size, data_shape,
                 path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
                 shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
                 data_name='data', label_name='label', last_batch_handle='pad', **kwargs):
        # The parent gets an empty augmenter list; detection augmenters are
        # stored separately because they transform (data, label) pairs.
        super(ImageDetIter, self).__init__(batch_size=batch_size, data_shape=data_shape,
                                           path_imgrec=path_imgrec,
                                           path_imglist=path_imglist,
                                           path_root=path_root,
                                           path_imgidx=path_imgidx,
                                           shuffle=shuffle, part_index=part_index,
                                           num_parts=num_parts, aug_list=[],
                                           imglist=imglist, data_name=data_name,
                                           label_name=label_name,
                                           last_batch_handle=last_batch_handle)

        if aug_list is None:
            self.auglist = CreateDetAugmenter(data_shape, **kwargs)
        else:
            self.auglist = aug_list

        # Go through all labels once to find the padded label shape.
        label_shape = self._estimate_label_shape()
        self.provide_label = [(label_name, (self.batch_size, label_shape[0], label_shape[1]))]
        self.label_shape = label_shape

    def _check_valid_label(self, label):
        """Validate label and its shape.

        Raises RuntimeError unless ``label`` is 2-D with at least 5 columns and
        has at least one row with ``id >= 0``, ``xmax > xmin`` and ``ymax > ymin``.
        """
        if len(label.shape) != 2 or label.shape[1] < 5:
            msg = "Label with shape (1+, 5+) required, %s received." % str(label)
            raise RuntimeError(msg)
        # np.logical_and is a binary ufunc: a third positional argument is the
        # `out` buffer, not another condition, so the checks must be nested.
        valid_label = np.where(np.logical_and(
            label[:, 0] >= 0,
            np.logical_and(label[:, 3] > label[:, 1], label[:, 4] > label[:, 2])))[0]
        if valid_label.size < 1:
            raise RuntimeError('Invalid label occurs.')

    def _estimate_label_shape(self):
        """Helper function to estimate label shape.

        Returns ``(max objects per image, object width)`` over the whole
        dataset. Raises RuntimeError if the dataset yields no sample.
        """
        max_count = 0
        obj_width = None
        self.reset()
        try:
            while True:
                label, _ = self.next_sample()
                label = self._parse_label(label)
                max_count = max(max_count, label.shape[0])
                obj_width = label.shape[1]
        except StopIteration:
            pass
        self.reset()
        if obj_width is None:
            raise RuntimeError('Unable to estimate label shape: no sample found.')
        return (max_count, obj_width)

    def _parse_label(self, label):
        """Helper function to parse object detection label.

        Format for raw label:
        n \t k \t ... \t [id \t xmin\t ymin \t xmax \t ymax \t ...] \t [repeat]
        where n is the width of header, 2 or larger
        k is the width of each object annotation, can be arbitrary, at least 5

        Returns the ``(num_objects, k)`` array of annotations with positive
        width and height. Raises RuntimeError for malformed or empty labels.
        """
        if isinstance(label, nd.NDArray):
            label = label.asnumpy()
        raw = label.ravel()
        if raw.size < 7:
            raise RuntimeError("Label shape is invalid: " + str(raw.shape))
        header_width = int(raw[0])
        obj_width = int(raw[1])
        # Guard against malformed headers: obj_width == 0 would otherwise
        # raise ZeroDivisionError, and widths below the minimum would raise
        # IndexError below.
        if header_width < 2 or obj_width < 5:
            msg = "Label header width %d or object width %d is too small." \
                % (header_width, obj_width)
            raise RuntimeError(msg)
        if (raw.size - header_width) % obj_width != 0:
            msg = "Label shape %s inconsistent with annotation width %d." \
                % (str(raw.shape), obj_width)
            raise RuntimeError(msg)
        out = np.reshape(raw[header_width:], (-1, obj_width))
        # remove bad ground-truths
        valid = np.where(np.logical_and(out[:, 3] > out[:, 1], out[:, 4] > out[:, 2]))[0]
        if valid.size < 1:
            raise RuntimeError('Encounter sample with no valid label.')
        return out[valid, :]

    def reshape(self, data_shape=None, label_shape=None):
        """Reshape iterator for data_shape or label_shape.

        Parameters
        ----------
        data_shape : tuple or None
            Reshape the data_shape to the new shape if not None
        label_shape : tuple or None
            Reshape label shape to new shape if not None
        """
        if data_shape is not None:
            self.check_data_shape(data_shape)
            self.provide_data = [(self.provide_data[0][0], (self.batch_size,) + data_shape)]
            self.data_shape = data_shape
        if label_shape is not None:
            self.check_label_shape(label_shape)
            self.provide_label = [(self.provide_label[0][0], (self.batch_size,) + label_shape)]
            self.label_shape = label_shape

    def _batchify(self, batch_data, batch_label, start=0):
        """Override the helper function for batchifying data.

        Fills ``batch_data``/``batch_label`` from index ``start`` and returns
        the index one past the last filled slot. Unused label rows are set to
        -1. Invalid samples are skipped. StopIteration propagates only if nothing
        was filled.
        """
        i = start
        batch_size = self.batch_size
        try:
            while i < batch_size:
                label, s = self.next_sample()
                data = self.imdecode(s)
                try:
                    self.check_valid_image([data])
                    label = self._parse_label(label)
                    data, label = self.augmentation_transform(data, label)
                    self._check_valid_label(label)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping: %s', str(e))
                    continue
                for datum in [data]:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i] = self.postprocess_data(datum)
                    num_object = label.shape[0]
                    batch_label[i][0:num_object] = nd.array(label)
                    if num_object < batch_label[i].shape[0]:
                        batch_label[i][num_object:] = -1
                    i += 1
        except StopIteration:
            if not i:
                raise StopIteration
        return i

    def next(self):
        """Override the function for returning next batch."""
        batch_size = self.batch_size
        c, h, w = self.data_shape
        # if last batch data is rolled over
        if self._cache_data is not None:
            # check both the data and label have values
            assert self._cache_label is not None, "_cache_label didn't have values"
            assert self._cache_idx is not None, "_cache_idx didn't have values"
            batch_data = self._cache_data
            batch_label = self._cache_label
            i = self._cache_idx
        else:
            batch_data = nd.zeros((batch_size, c, h, w))
            batch_label = nd.empty(self.provide_label[0][1])
            batch_label[:] = -1
            i = self._batchify(batch_data, batch_label)
        # calculate the padding
        pad = batch_size - i
        # handle padding for the last batch
        if pad != 0:
            if self.last_batch_handle == 'discard':
                raise StopIteration
            # if the option is 'roll_over', throw StopIteration and cache the data
            elif self.last_batch_handle == 'roll_over' and \
                    self._cache_data is None:
                self._cache_data = batch_data
                self._cache_label = batch_label
                self._cache_idx = i
                raise StopIteration
            else:
                _ = self._batchify(batch_data, batch_label, i)
                if self.last_batch_handle == 'pad':
                    self._allow_read = False
                else:
                    self._cache_data = None
                    self._cache_label = None
                    self._cache_idx = None

        return io.DataBatch([batch_data], [batch_label], pad=pad)

    def augmentation_transform(self, data, label):  # pylint: disable=arguments-differ
        """Override Transforms input data with specified augmentations."""
        for aug in self.auglist:
            data, label = aug(data, label)
        return (data, label)

    def check_label_shape(self, label_shape):
        """Checks if the new label shape is valid.

        Raises ValueError if ``label_shape`` is not 2-D, shrinks the object
        count, or changes the object width.
        """
        if not len(label_shape) == 2:
            raise ValueError('label_shape should have length 2')
        if label_shape[0] < self.label_shape[0]:
            msg = 'Attempts to reduce label count from %d to %d, not allowed.' \
                % (self.label_shape[0], label_shape[0])
            raise ValueError(msg)
        if label_shape[1] != self.provide_label[0][1][2]:
            msg = 'label_shape object width inconsistent: %d vs %d.' \
                % (self.provide_label[0][1][2], label_shape[1])
            raise ValueError(msg)

    def draw_next(self, color=None, thickness=2, mean=None, std=None, clip=True,
                  waitKey=None, window_name='draw_next', id2labels=None):
        """Display next image with bounding boxes drawn.

        Parameters
        ----------
        color : tuple
            Bounding box color in RGB, use None for random color
        thickness : int
            Bounding box border thickness
        mean : True or numpy.ndarray
            Compensate for the mean to have better visual effect
        std : True or numpy.ndarray
            Revert standard deviations
        clip : bool
            If true, clip to [0, 255] for better visual effect
        waitKey : None or int
            Hold the window for waitKey milliseconds if set, skip plotting if None
        window_name : str
            Plot window name if waitKey is set.
        id2labels : dict
            Mapping of labels id to labels name.

        Returns
        -------
            numpy.ndarray

        Examples
        --------
        >>> # use draw_next to get images with bounding boxes drawn
        >>> iterator = mx.image.ImageDetIter(1, (3, 600, 600), path_imgrec='train.rec')
        >>> for image in iterator.draw_next(waitKey=None):
        ...     # display image
        >>> # or let draw_next display using cv2 module
        >>> for image in iterator.draw_next(waitKey=0, window_name='disp'):
        ...     pass
        """
        try:
            import cv2
        except ImportError as e:
            # warnings.warn's second positional argument is the category, so
            # the message must be formatted before the call.
            warnings.warn('Unable to import cv2, skip drawing: %s' % str(e))
            return

        # Resolve normalization constants once rather than per image.
        if std is True:
            std = np.array([58.395, 57.12, 57.375])
        elif std is not None:
            assert isinstance(std, np.ndarray) and std.shape[0] in [1, 3]
        if mean is True:
            mean = np.array([123.68, 116.28, 103.53])
        elif mean is not None:
            assert isinstance(mean, np.ndarray) and mean.shape[0] in [1, 3]
        # Images are converted to BGR below, so convert the user's RGB color
        # once. Reversing it inside the loop flipped it back on every other image.
        if color:
            color = color[::-1]

        try:
            while True:
                label, s = self.next_sample()
                data = self.imdecode(s)
                try:
                    self.check_valid_image([data])
                    label = self._parse_label(label)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping: %s', str(e))
                    continue
                data, label = self.augmentation_transform(data, label)
                image = data.asnumpy()

                # revert color_normalize
                if std is not None:
                    image *= std
                if mean is not None:
                    image += mean

                # swap RGB -> BGR for OpenCV
                image[:, :, (0, 1, 2)] = image[:, :, (2, 1, 0)]
                if clip:
                    image = np.maximum(0, np.minimum(255, image))
                image = image.astype(np.uint8)
                height, width, _ = image.shape
                for i in range(label.shape[0]):
                    x1 = int(label[i, 1] * width)
                    if x1 < 0:
                        continue  # padded (-1) rows
                    y1 = int(label[i, 2] * height)
                    x2 = int(label[i, 3] * width)
                    y2 = int(label[i, 4] * height)
                    bc = np.random.rand(3) * 255 if not color else color
                    cv2.rectangle(image, (x1, y1), (x2, y2), bc, thickness)
                    if id2labels is not None:
                        cls_id = int(label[i, 0])
                        if cls_id in id2labels:
                            cls_name = id2labels[cls_id]
                            text = "{:s}".format(cls_name)
                            font = cv2.FONT_HERSHEY_SIMPLEX
                            font_scale = 0.5
                            text_height = cv2.getTextSize(text, font, font_scale, 2)[0][1]
                            tc = (255, 255, 255)
                            tpos = (x1 + 5, y1 + text_height + 5)
                            cv2.putText(image, text, tpos, font, font_scale, tc, 2)
                if waitKey is not None:
                    cv2.imshow(window_name, image)
                    cv2.waitKey(waitKey)
                yield image
        except StopIteration:
            return

    def sync_label_shape(self, it, verbose=False):
        """Synchronize label shape with the input iterator. This is useful when
        train/validation iterators have different label padding.

        Parameters
        ----------
        it : ImageDetIter
            The other iterator to synchronize
        verbose : bool
            Print verbose log if true

        Returns
        -------
        ImageDetIter
            The synchronized other iterator, the internal label shape is updated as well.

        Examples
        --------
        >>> train_iter = mx.image.ImageDetIter(32, (3, 300, 300), path_imgrec='train.rec')
        >>> val_iter = mx.image.ImageDetIter(32, (3, 300, 300), path_imgrec='val.rec')
        >>> train_iter.label_shape
        (30, 6)
        >>> val_iter.label_shape
        (25, 6)
        >>> val_iter = train_iter.sync_label_shape(val_iter, verbose=False)
        >>> train_iter.label_shape
        (30, 6)
        >>> val_iter.label_shape
        (30, 6)
        """
        assert isinstance(it, ImageDetIter), 'Synchronize with invalid iterator.'
        train_label_shape = self.label_shape
        val_label_shape = it.label_shape
        assert train_label_shape[1] == val_label_shape[1], "object width mismatch."
        max_count = max(train_label_shape[0], val_label_shape[0])
        if max_count > train_label_shape[0]:
            self.reshape(None, (max_count, train_label_shape[1]))
        if max_count > val_label_shape[0]:
            it.reshape(None, (max_count, val_label_shape[1]))
        if verbose and max_count > min(train_label_shape[0], val_label_shape[0]):
            logging.info('Resized label_shape to (%d, %d).', max_count, train_label_shape[1])
        return it