# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-locals
"""Definition of various recurrent neural network cells."""
from __future__ import print_function
import bisect
import random
import numpy as np
from ..io import DataIter, DataBatch, DataDesc
from .. import ndarray
def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
                     start_label=0, unknown_token=None):
    """Encode sentences and (optionally) build a mapping
    from string tokens to integer indices. Unknown keys
    will be added to vocabulary.

    Parameters
    ----------
    sentences : list of list of str
        A list of sentences to encode. Each sentence
        should be a list of string tokens.
    vocab : None or dict of str -> int
        Optional input vocabulary. When given, it is updated in place and
        returned. When None, a new vocabulary is built that initially maps
        ``invalid_key`` to ``invalid_label``.
    invalid_label : int, default -1
        Index for the invalid token, like the end-of-sentence marker.
        Newly added tokens are never assigned this index.
    invalid_key : str, default '\\n'
        Key for invalid token. Use '\\n' for end
        of sentence by default.
    start_label : int
        Lowest index assigned to a newly added token. Indices already used
        by ``vocab`` (and ``invalid_label``) are skipped.
    unknown_token : str
        Symbol used to represent tokens missing from ``vocab``. All such
        tokens are encoded with the single index of ``unknown_token``, which
        is added to ``vocab`` the first time it is needed (an existing entry
        is reused). If not specified, a token missing from a user-supplied
        ``vocab`` raises ``AssertionError``, while a token missing from a
        newly built vocabulary is added to it.

    Returns
    -------
    result : list of list of int
        encoded sentences
    vocab : dict of str -> int
        result vocabulary
    """
    if vocab is None:
        vocab = {invalid_key: invalid_label}
        new_vocab = True
    else:
        new_vocab = False
    # Indices that must never be handed out to a new token: the invalid label
    # and, for a user-supplied vocabulary, every index it already uses.
    taken = set(vocab.values())
    taken.add(invalid_label)
    idx = start_label
    res = []
    for sent in sentences:
        coded = []
        for word in sent:
            if word not in vocab:
                assert (new_vocab or unknown_token), "Unknown token %s" % word
                if unknown_token:
                    # Every out-of-vocabulary word shares the unknown token's
                    # index, so only allocate that index once.
                    word = unknown_token
                if word not in vocab:
                    while idx in taken:
                        idx += 1
                    vocab[word] = idx
                    taken.add(idx)
                    idx += 1
            coded.append(vocab[word])
        res.append(coded)
    return res, vocab
class BucketSentenceIter(DataIter):
    """Simple bucketing iterator for language model.

    The label at each sequence step is the following token
    in the sequence.

    Parameters
    ----------
    sentences : list of list of int
        Encoded sentences.
    batch_size : int
        Batch size of the data.
    invalid_label : int, optional
        Key for invalid label, e.g. <end-of-sentence>. The default is -1.
        Also used to pad sentences shorter than their bucket and as the
        label of the last step of every sequence.
    dtype : str, optional
        Data type of the encoding. The default data type is 'float32'.
    buckets : list of int, optional
        Size of the data buckets. Automatically generated if None: every
        sentence length that occurs at least ``batch_size`` times becomes a
        bucket. The list passed in is not modified.
    data_name : str, optional
        Name of the data. The default name is 'data'.
    label_name : str, optional
        Name of the label. The default name is 'softmax_label'.
    layout : str, optional
        Format of data and label. 'NT' means (batch_size, length)
        and 'TN' means (length, batch_size).

    Raises
    ------
    ValueError
        If the batch axis 'N' is not the first or second character of
        ``layout``, or if no sentence fits into any bucket.
    """
    def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
                 data_name='data', label_name='softmax_label', dtype='float32',
                 layout='NT'):
        super(BucketSentenceIter, self).__init__()
        # Position of the batch axis: 0 for batch-major ('NT'), 1 for time-major ('TN').
        # Validate before doing any (potentially expensive) bucketing work.
        major_axis = layout.find('N')
        if major_axis not in (0, 1):
            raise ValueError("Invalid layout %s: Must be NT (batch major) or TN (time major)"
                             % layout)

        if not buckets:
            # One bucket per sentence length that can fill at least one whole batch.
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        # Sort a copy so a caller-supplied list is left untouched.
        buckets = sorted(buckets)

        ndiscard = 0
        self.data = [[] for _ in buckets]
        for sent in sentences:
            # Index of the smallest bucket that can hold the whole sentence.
            buck = bisect.bisect_left(buckets, len(sent))
            if buck == len(buckets):
                ndiscard += 1
                continue
            # Right-pad the sentence with invalid_label up to the bucket length.
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sent)] = sent
            self.data[buck].append(buff)

        # Drop buckets that received no sentence, keeping keys and data aligned.
        buckets = [key for key, rows in zip(buckets, self.data) if rows]
        self.data = [np.asarray(rows, dtype=dtype) for rows in self.data if rows]
        if ndiscard:
            print("WARNING: discarded %d sentences longer than the largest bucket." % ndiscard)
        if not buckets:
            raise ValueError("No sentence fits into any bucket; cannot build the iterator.")

        self.batch_size = batch_size
        self.buckets = buckets
        self.data_name = data_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nddata = []
        self.ndlabel = []
        self.major_axis = major_axis
        self.layout = layout
        self.default_bucket_key = max(buckets)

        # Shape of the largest batch, ordered according to the layout.
        shape = (batch_size, self.default_bucket_key)
        if major_axis == 1:
            shape = (self.default_bucket_key, batch_size)
        self.provide_data = [DataDesc(name=self.data_name, shape=shape, layout=self.layout)]
        self.provide_label = [DataDesc(name=self.label_name, shape=shape, layout=self.layout)]

        # (bucket index, row offset) of every full batch; trailing partial batches are dropped.
        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0
        self.reset()

    def reset(self):
        """Resets the iterator to the beginning of the data.

        Shuffles the batch order and (in place) the rows of every bucket, then
        rebuilds the NDArray copies of data and labels. Each label row is its
        data row shifted left by one step, with ``invalid_label`` in the last
        position.
        """
        self.curr_idx = 0
        random.shuffle(self.idx)
        for buck in self.data:
            np.random.shuffle(buck)

        self.nddata = []
        self.ndlabel = []
        for buck in self.data:
            label = np.empty_like(buck)
            label[:, :-1] = buck[:, 1:]
            label[:, -1] = self.invalid_label
            self.nddata.append(ndarray.array(buck, dtype=self.dtype))
            self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

    def next(self):
        """Returns the next batch of data.

        Raises
        ------
        StopIteration
            When every batch of the current epoch has been returned.
        """
        if self.curr_idx >= len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        data = self.nddata[i][j:j+self.batch_size]
        label = self.ndlabel[i][j:j+self.batch_size]
        if self.major_axis == 1:
            # Time-major layout: (batch, time) -> (time, batch).
            data = data.T
            label = label.T
        return DataBatch([data], [label], pad=0,
                         bucket_key=self.buckets[i],
                         provide_data=[DataDesc(
                             name=self.data_name, shape=data.shape,
                             layout=self.layout)],
                         provide_label=[DataDesc(
                             name=self.label_name, shape=label.shape,
                             layout=self.layout)])