Source code for mxnet.recordio

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Read and write for the RecordIO data format."""
from __future__ import absolute_import
from collections import namedtuple
from multiprocessing import current_process

import ctypes
import struct
import numbers
import numpy as np

from .base import _LIB
from .base import RecordIOHandle
from .base import check_call
from .base import c_str
try:
    import cv2
except ImportError:
    cv2 = None

class MXRecordIO(object):
    """Reads/writes `RecordIO` data format, supporting sequential read and write.

    Examples
    ---------
    >>> record = mx.recordio.MXRecordIO('tmp.rec', 'w')
    >>> for i in range(5):
    ...     record.write('record_%d'%i)
    >>> record.close()
    >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
    >>> for i in range(5):
    ...     item = record.read()
    ...     print(item)
    record_0
    record_1
    record_2
    record_3
    record_4
    >>> record.close()

    Parameters
    ----------
    uri : string
        Path to the record file.
    flag : string
        'w' for write or 'r' for read.
    """
    def __init__(self, uri, flag):
        self.uri = c_str(uri)
        self.handle = RecordIOHandle()
        self.flag = flag
        # pid of the process that opened the handle; None while closed.
        self.pid = None
        self.is_open = False
        self.open()

    def open(self):
        """Opens the record file.

        Creates a writer handle when ``flag`` is 'w' and a reader handle when
        it is 'r', then records the pid of the opening process.

        Raises
        ------
        ValueError
            If ``flag`` is neither 'w' nor 'r'.
        """
        if self.flag == "w":
            check_call(_LIB.MXRecordIOWriterCreate(self.uri, ctypes.byref(self.handle)))
            self.writable = True
        elif self.flag == "r":
            check_call(_LIB.MXRecordIOReaderCreate(self.uri, ctypes.byref(self.handle)))
            self.writable = False
        else:
            raise ValueError("Invalid flag %s"%self.flag)
        # pylint: disable=not-callable
        # It's bug from pylint(astroid). See https://github.com/PyCQA/pylint/issues/1699
        self.pid = current_process().pid
        self.is_open = True

    def __del__(self):
        self.close()

    def __getstate__(self):
        """Override pickling behavior.

        Note that the record is closed as a side effect; the previous open
        state is stored in the returned dict so that ``__setstate__`` can
        reopen the copy.
        """
        # pickling pointer is not allowed
        is_open = self.is_open
        self.close()
        d = dict(self.__dict__)
        d['is_open'] = is_open
        uri = self.uri.value
        try:
            uri = uri.decode('utf-8')
        except AttributeError:
            # value has no ``decode`` (already a text string); keep as is
            pass
        del d['handle']
        d['uri'] = uri
        return d

    def __setstate__(self, d):
        """Restore from pickled."""
        self.__dict__ = d
        is_open = d['is_open']
        # A fresh handle is needed; reopen only if the source was open.
        self.is_open = False
        self.handle = RecordIOHandle()
        self.uri = c_str(self.uri)
        if is_open:
            self.open()

    def _check_pid(self, allow_reset=False):
        """Check process id to ensure integrity, reset if in new process.

        Parameters
        ----------
        allow_reset : bool
            If True, a pid mismatch triggers ``reset()``; otherwise it raises.

        Raises
        ------
        RuntimeError
            If the stored pid differs from the current process and
            ``allow_reset`` is False.
        """
        # pylint: disable=not-callable
        # It's bug from pylint(astroid). See https://github.com/PyCQA/pylint/issues/1699
        if self.pid != current_process().pid:
            if allow_reset:
                self.reset()
            else:
                raise RuntimeError("Forbidden operation in multiple processes")

    def close(self):
        """Closes the record file.

        Does nothing when the record is not open, including when the instance
        was only partially constructed (``__init__`` failed before the state
        attributes were set), so that ``__del__`` never raises for that reason.
        """
        if not getattr(self, 'is_open', False):
            return
        if self.writable:
            check_call(_LIB.MXRecordIOWriterFree(self.handle))
        else:
            check_call(_LIB.MXRecordIOReaderFree(self.handle))
        self.is_open = False
        self.pid = None

    def reset(self):
        """Resets the pointer to first item.

        If the record is opened with 'w', this function will truncate the file to empty.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
        >>> for i in range(2):
        ...     item = record.read()
        ...     print(item)
        record_0
        record_1
        >>> record.reset()  # Pointer is reset.
        >>> print(record.read()) # Started reading from start again.
        record_0
        >>> record.close()
        """
        self.close()
        self.open()

    def write(self, buf):
        """Inserts a string buffer as a record.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'w')
        >>> for i in range(5):
        ...     record.write('record_%d'%i)
        >>> record.close()

        Parameters
        ----------
        buf : string (python2), bytes (python3)
            Buffer to write.

        Raises
        ------
        RuntimeError
            If called from a process other than the one that opened the record.
        """
        assert self.writable
        self._check_pid(allow_reset=False)
        check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
                                                    ctypes.c_char_p(buf),
                                                    ctypes.c_size_t(len(buf))))

    def read(self):
        """Returns record as a string.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
        >>> for i in range(5):
        ...     item = record.read()
        ...     print(item)
        record_0
        record_1
        record_2
        record_3
        record_4
        >>> record.close()

        Returns
        ----------
        buf : string
            Buffer read, or None when the reader hands back a null buffer
            (presumably end of file).

        Raises
        ------
        RuntimeError
            If called from a process other than the one that opened the record.
        """
        assert not self.writable
        # trying to implicitly read from multiple processes is forbidden,
        # there's no elegant way to handle unless lock is introduced
        self._check_pid(allow_reset=False)
        buf = ctypes.c_char_p()
        size = ctypes.c_size_t()
        check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
                                                   ctypes.byref(buf),
                                                   ctypes.byref(size)))
        if buf:
            # Reinterpret as a fixed-size char array so that embedded NUL
            # bytes are kept; ``raw`` copies exactly ``size`` bytes.
            buf = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char*size.value))
            return buf.contents.raw
        else:
            return None
class MXIndexedRecordIO(MXRecordIO):
    """Reads/writes `RecordIO` data format, supporting random access.

    Examples
    ---------
    >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
    >>> for i in range(5):
    ...     record.write_idx(i, 'record_%d'%i)
    >>> record.close()
    >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
    >>> record.read_idx(3)
    record_3

    Parameters
    ----------
    idx_path : str
        Path to the index file.
    uri : str
        Path to the record file. Only supports seekable file types.
    flag : str
        'w' for write or 'r' for read.
    key_type : type
        Data type for keys.
    """
    def __init__(self, idx_path, uri, flag, key_type=int):
        self.idx_path = idx_path
        self.idx = {}       # key -> position in the record file
        self.keys = []      # keys in the order they were loaded/written
        self.key_type = key_type
        self.fidx = None    # index file object, set by open()
        super(MXIndexedRecordIO, self).__init__(uri, flag)

    def open(self):
        """Opens the record file and the index file.

        In read mode the index file is parsed as tab-separated
        ``key<TAB>position`` lines into ``self.idx`` / ``self.keys``.
        Blank lines are skipped.
        """
        super(MXIndexedRecordIO, self).open()
        self.idx = {}
        self.keys = []
        self.fidx = open(self.idx_path, self.flag)
        if not self.writable:
            for line in iter(self.fidx.readline, ''):
                line = line.strip()
                if not line:
                    # tolerate blank lines instead of failing on key parsing
                    continue
                fields = line.split('\t')
                key = self.key_type(fields[0])
                self.idx[key] = int(fields[1])
                self.keys.append(key)

    def close(self):
        """Closes the record file.

        Does nothing when the record is not open. The index file is closed
        even if releasing the record handle raises, and a missing index file
        object (opening the index failed after the record file was opened)
        is tolerated.
        """
        if not getattr(self, 'is_open', False):
            return
        try:
            super(MXIndexedRecordIO, self).close()
        finally:
            if self.fidx is not None:
                self.fidx.close()

    def __getstate__(self):
        """Override pickling behavior."""
        d = super(MXIndexedRecordIO, self).__getstate__()
        # file objects cannot be pickled; open() recreates it on restore
        d['fidx'] = None
        return d

    def seek(self, idx):
        """Sets the current read pointer position.

        This function is internally called by `read_idx(idx)` to find the current
        reader pointer position. It doesn't return anything.

        Parameters
        ----------
        idx : key
            Key of the record to seek to; it must be present in the index,
            otherwise the dictionary lookup raises ``KeyError``.
        """
        assert not self.writable
        # Unlike read(), a pid mismatch here reopens the record.
        self._check_pid(allow_reset=True)
        pos = ctypes.c_size_t(self.idx[idx])
        check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))

    def tell(self):
        """Returns the current position of write head.

        Examples
        ---------
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
        >>> print(record.tell())
        0
        >>> for i in range(5):
        ...     record.write_idx(i, 'record_%d'%i)
        ...     print(record.tell())
        16
        32
        48
        64
        80
        """
        assert self.writable
        pos = ctypes.c_size_t()
        check_call(_LIB.MXRecordIOWriterTell(self.handle, ctypes.byref(pos)))
        return pos.value

    def read_idx(self, idx):
        """Returns the record at given index.

        Examples
        ---------
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
        >>> for i in range(5):
        ...     record.write_idx(i, 'record_%d'%i)
        >>> record.close()
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
        >>> record.read_idx(3)
        record_3
        """
        self.seek(idx)
        return self.read()

    def write_idx(self, idx, buf):
        """Inserts input record at given index.

        Examples
        ---------
        >>> for i in range(5):
        ...     record.write_idx(i, 'record_%d'%i)
        >>> record.close()

        Parameters
        ----------
        idx : int
            Index of a file.
        buf :
            Record to write.
        """
        key = self.key_type(idx)
        # Position must be taken before writing: it is where this record starts.
        pos = self.tell()
        self.write(buf)
        self.fidx.write('%s\t%d\n'%(str(key), pos))
        self.idx[key] = pos
        self.keys.append(key)
IRHeader = namedtuple('HEADER', 'flag label id id2')
"""Metadata header stored alongside a record (the tuple type is named HEADER).

See mxnet.recordio.pack and mxnet.recordio.pack_img for example uses.

Parameters
----------
    flag : int
        Available for convenience, can be set arbitrarily.
    label : float or an array of float
        Typically used to store label(s) for a record.
    id: int
        Usually a unique id representing record.
    id2: int
        Higher order bits of the unique id, should be set to 0 (in most cases).
"""

# struct layout of a packed header: unsigned int, float, two unsigned long longs
_IR_FORMAT = 'IfQQ'
_IR_SIZE = struct.Struct(_IR_FORMAT).size
def pack(header, s):
    """Pack a string into MXImageRecord.

    When ``header.label`` is a single number the header is packed as is (with
    ``flag`` forced to 0). Otherwise the label is converted to a float32 array,
    ``flag`` is set to the number of label values, the header's label field is
    set to 0, and the raw label bytes are placed in front of ``s``.

    Parameters
    ----------
    header : IRHeader
        Header of the image record.
        ``header.label`` can be a number or an array. See more detail in ``IRHeader``.
    s : str
        Raw image string to be packed.

    Returns
    -------
    s : str
        The packed string.

    Examples
    --------
    >>> label = 4 # label can also be a 1-D array, for example: label = [1,2,3]
    >>> id = 2574
    >>> header = mx.recordio.IRHeader(0, label, id, 0)
    >>> with open(path, 'rb') as file:
    ...     s = file.read()
    >>> packed_s = mx.recordio.pack(header, s)
    """
    header = IRHeader(*header)
    if isinstance(header.label, numbers.Number):
        header = header._replace(flag=0)
    else:
        label = np.asarray(header.label, dtype=np.float32)
        header = header._replace(flag=label.size, label=0)
        # tobytes() replaces the deprecated ndarray.tostring() alias;
        # both produce the same raw bytes.
        s = label.tobytes() + s
    return struct.pack(_IR_FORMAT, *header) + s
def unpack(s):
    """Unpack a MXImageRecord to string.

    The fixed-size header is decoded first. If its ``flag`` is positive, that
    many float32 label values are read from the start of the remaining bytes
    and stored as ``header.label``; the rest is returned as the payload.

    Parameters
    ----------
    s : str
        String buffer from ``MXRecordIO.read``.

    Returns
    -------
    header : IRHeader
        Header of the image record.
    s : str
        Unpacked string.

    Examples
    --------
    >>> record = mx.recordio.MXRecordIO('test.rec', 'r')
    >>> item = record.read()
    >>> header, s = mx.recordio.unpack(item)
    >>> header
    HEADER(flag=0, label=14.0, id=20129312, id2=0)
    """
    raw_fields = struct.unpack(_IR_FORMAT, s[:_IR_SIZE])
    header = IRHeader(*raw_fields)
    payload = s[_IR_SIZE:]
    if not header.flag > 0:
        return header, payload
    # flag holds the number of float32 label values that precede the payload
    num_labels = header.flag
    labels = np.frombuffer(payload, np.float32, num_labels)
    return header._replace(label=labels), payload[num_labels*4:]
def unpack_img(s, iscolor=-1):
    """Unpack a MXImageRecord to image.

    Requires OpenCV (``cv2``) to decode the image bytes.

    Parameters
    ----------
    s : str
        String buffer from ``MXRecordIO.read``.
    iscolor : int
        Image format option for ``cv2.imdecode``.

    Returns
    -------
    header : IRHeader
        Header of the image record.
    img : numpy.ndarray
        Unpacked image.

    Examples
    --------
    >>> record = mx.recordio.MXRecordIO('test.rec', 'r')
    >>> item = record.read()
    >>> header, img = mx.recordio.unpack_img(item)
    >>> header
    HEADER(flag=0, label=14.0, id=20129312, id2=0)
    >>> img
    array([[[ 23,  27,  45],
            [ 28,  32,  50],
            ...,
            [ 36,  40,  59],
            [ 35,  39,  58]],
           ...,
           [[ 91,  92, 113],
            [ 97,  98, 119],
            ...,
            [168, 169, 167],
            [166, 167, 165]]], dtype=uint8)
    """
    header, encoded = unpack(s)
    # view the encoded image as a flat byte array, the input cv2.imdecode expects
    pixels = np.frombuffer(encoded, dtype=np.uint8)
    assert cv2 is not None
    return header, cv2.imdecode(pixels, iscolor)
def pack_img(header, img, quality=95, img_fmt='.jpg'):
    """Pack an image into ``MXImageRecord``.

    Requires OpenCV (``cv2``) to encode the image. ``quality`` is only passed
    to the encoder for JPEG and PNG formats; any other ``img_fmt`` is encoded
    without explicit encoder parameters.

    Parameters
    ----------
    header : IRHeader
        Header of the image record.
        ``header.label`` can be a number or an array. See more detail in ``IRHeader``.
    img : numpy.ndarray
        Image to be packed.
    quality : int
        Quality for JPEG encoding in range 1-100, or compression for PNG encoding in range 1-9.
    img_fmt : str
        Encoding of the image (.jpg for JPEG, .png for PNG).

    Returns
    -------
    s : str
        The packed string.

    Examples
    --------
    >>> label = 4 # label can also be a 1-D array, for example: label = [1,2,3]
    >>> id = 2574
    >>> header = mx.recordio.IRHeader(0, label, id, 0)
    >>> img = cv2.imread('test.jpg')
    >>> packed_s = mx.recordio.pack_img(header, img)
    """
    assert cv2 is not None
    jpg_formats = ['.JPG', '.JPEG']
    png_formats = ['.PNG']
    encode_params = None
    # extension match is case-insensitive
    fmt_upper = img_fmt.upper()
    if fmt_upper in jpg_formats:
        encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
    elif fmt_upper in png_formats:
        encode_params = [cv2.IMWRITE_PNG_COMPRESSION, quality]

    ret, buf = cv2.imencode(img_fmt, img, encode_params)
    assert ret, 'failed to encode image'
    # tobytes() replaces the deprecated ndarray.tostring() alias;
    # both produce the same raw bytes.
    return pack(header, buf.tobytes())