# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=consider-iterating-dictionary
# pylint: disable=super-init-not-called
"""Text token embeddings."""
from __future__ import absolute_import
from __future__ import print_function
import io
import logging
import os
import tarfile
import warnings
import zipfile
from . import _constants as C
from . import vocab
from ... import ndarray as nd
from ... import registry
from ... import base
[docs]def register(embedding_cls):
"""Registers a new token embedding.
Once an embedding is registered, we can create an instance of this embedding with
:func:`~mxnet.contrib.text.embedding.create`.
Examples
--------
>>> @mxnet.contrib.text.embedding.register
... class MyTextEmbed(mxnet.contrib.text.embedding._TokenEmbedding):
... def __init__(self, pretrained_file_name='my_pretrain_file'):
... pass
>>> embed = mxnet.contrib.text.embedding.create('MyTokenEmbed')
>>> print(type(embed))
"""
register_text_embedding = registry.get_register_func(_TokenEmbedding, 'token embedding')
return register_text_embedding(embedding_cls)
[docs]def create(embedding_name, **kwargs):
"""Creates an instance of token embedding.
Creates a token embedding instance by loading embedding vectors from an externally hosted
pre-trained token embedding file, such as those of GloVe and FastText. To get all the valid
`embedding_name` and `pretrained_file_name`, use
`mxnet.contrib.text.embedding.get_pretrained_file_names()`.
Parameters
----------
embedding_name : str
The token embedding name (case-insensitive).
Returns
-------
An instance of `mxnet.contrib.text.glossary._TokenEmbedding`:
A token embedding instance that loads embedding vectors from an externally hosted
pre-trained token embedding file.
"""
create_text_embedding = registry.get_create_func(_TokenEmbedding, 'token embedding')
return create_text_embedding(embedding_name, **kwargs)
[docs]def get_pretrained_file_names(embedding_name=None):
"""Get valid token embedding names and their pre-trained file names.
To load token embedding vectors from an externally hosted pre-trained token embedding file,
such as those of GloVe and FastText, one should use
`mxnet.contrib.text.embedding.create(embedding_name, pretrained_file_name)`.
This method returns all the valid names of `pretrained_file_name` for the specified
`embedding_name`. If `embedding_name` is set to None, this method returns all the valid
names of `embedding_name` with their associated `pretrained_file_name`.
Parameters
----------
embedding_name : str or None, default None
The pre-trained token embedding name.
Returns
-------
dict or list:
A list of all the valid pre-trained token embedding file names (`pretrained_file_name`)
for the specified token embedding name (`embedding_name`). If the text embeding name is
set to None, returns a dict mapping each valid token embedding name to a list of valid
pre-trained files (`pretrained_file_name`). They can be plugged into
`mxnet.contrib.text.embedding.create(embedding_name,
pretrained_file_name)`.
"""
text_embedding_reg = registry.get_registry(_TokenEmbedding)
if embedding_name is not None:
if embedding_name not in text_embedding_reg:
raise KeyError('Cannot find `embedding_name` %s. Use '
'`get_pretrained_file_names('
'embedding_name=None).keys()` to get all the valid embedding '
'names.' % embedding_name)
return list(text_embedding_reg[embedding_name].pretrained_file_name_sha1.keys())
else:
return {embedding_name: list(embedding_cls.pretrained_file_name_sha1.keys())
for embedding_name, embedding_cls in registry.get_registry(_TokenEmbedding).items()}
class _TokenEmbedding(vocab.Vocabulary):
"""Token embedding base class.
To load token embeddings from an externally hosted pre-trained token embedding file, such as
those of GloVe and FastText, use
:func:`~mxnet.contrib.text.embedding.create(embedding_name, pretrained_file_name)`.
To get all the available `embedding_name` and `pretrained_file_name`, use
:func:`~mxnet.contrib.text.embedding.get_pretrained_file_names()`.
Alternatively, to load embedding vectors from a custom pre-trained token embedding file, use
:class:`~mxnet.contrib.text.embedding.CustomEmbedding`.
Moreover, to load composite embedding vectors, such as to concatenate embedding vectors, use
:class:`~mxnet.contrib.text.embedding.CompositeEmbedding`.
For every unknown token, if its representation `self.unknown_token` is encountered in the
pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the pre-trained token
embedding vector loaded from the file; otherwise, index 0 of `self.idx_to_vec` maps to the
token embedding vector initialized by `init_unknown_vec`.
If a token is encountered multiple times in the pre-trained token embedding file, only the
first-encountered token embedding vector will be loaded and the rest will be skipped.
The indexed tokens in a text token embedding may come from a vocabulary or from the loaded
embedding vectors. In the former case, only the indexed tokens in a vocabulary are associated
with the loaded embedding vectors, such as loaded from a pre-trained token embedding file. In
the later case, all the tokens from the loaded embedding vectors, such as loaded from a
pre-trained token embedding file, are taken as the indexed tokens of the embedding.
Attributes
----------
token_to_idx : dict mapping str to int
A dict mapping each token to its index integer.
idx_to_token : list of strs
A list of indexed tokens where the list indices and the token indices are aligned.
unknown_token : hashable object
The representation for any unknown token. In other words, any unknown token will be indexed
as the same representation.
reserved_tokens : list of strs or None
A list of reserved tokens that will always be indexed.
vec_len : int
The length of the embedding vector for each token.
idx_to_vec : mxnet.ndarray.NDArray
For all the indexed tokens in this embedding, this NDArray maps each token's index to an
embedding vector. The largest valid index maps to the initialized embedding vector for every
reserved token, such as an unknown_token token and a padding token.
"""
def __init__(self, **kwargs):
super(_TokenEmbedding, self).__init__(**kwargs)
@classmethod
def _get_download_file_name(cls, pretrained_file_name):
return pretrained_file_name
@classmethod
def _get_pretrained_file_url(cls, pretrained_file_name):
repo_url = os.environ.get('MXNET_GLUON_REPO', C.APACHE_REPO_URL)
embedding_cls = cls.__name__.lower()
url_format = '{repo_url}gluon/embeddings/{cls}/{file_name}'
return url_format.format(repo_url=repo_url, cls=embedding_cls,
file_name=cls._get_download_file_name(pretrained_file_name))
@classmethod
def _get_pretrained_file(cls, embedding_root, pretrained_file_name):
from ...gluon.utils import check_sha1, download
embedding_cls = cls.__name__.lower()
embedding_root = os.path.expanduser(embedding_root)
url = cls._get_pretrained_file_url(pretrained_file_name)
embedding_dir = os.path.join(embedding_root, embedding_cls)
pretrained_file_path = os.path.join(embedding_dir, pretrained_file_name)
downloaded_file = os.path.basename(url)
downloaded_file_path = os.path.join(embedding_dir, downloaded_file)
expected_file_hash = cls.pretrained_file_name_sha1[pretrained_file_name]
if hasattr(cls, 'pretrained_archive_name_sha1'):
expected_downloaded_hash = \
cls.pretrained_archive_name_sha1[downloaded_file]
else:
expected_downloaded_hash = expected_file_hash
if not os.path.exists(pretrained_file_path) \
or not check_sha1(pretrained_file_path, expected_file_hash):
download(url, downloaded_file_path, sha1_hash=expected_downloaded_hash)
ext = os.path.splitext(downloaded_file)[1]
if ext == '.zip':
with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
zf.extractall(embedding_dir)
elif ext == '.gz':
with tarfile.open(downloaded_file_path, 'r:gz') as tar:
tar.extractall(path=embedding_dir)
return pretrained_file_path
def _load_embedding(self, pretrained_file_path, elem_delim, init_unknown_vec, encoding='utf8'):
"""Load embedding vectors from the pre-trained token embedding file.
For every unknown token, if its representation `self.unknown_token` is encountered in the
pre-trained token embedding file, index 0 of `self.idx_to_vec` maps to the pre-trained token
embedding vector loaded from the file; otherwise, index 0 of `self.idx_to_vec` maps to the
text embedding vector initialized by `init_unknown_vec`.
If a token is encountered multiple times in the pre-trained text embedding file, only the
first-encountered token embedding vector will be loaded and the rest will be skipped.
"""
pretrained_file_path = os.path.expanduser(pretrained_file_path)
if not os.path.isfile(pretrained_file_path):
raise ValueError('`pretrained_file_path` must be a valid path to '
'the pre-trained token embedding file.')
logging.info('Loading pre-trained token embedding vectors from %s', pretrained_file_path)
vec_len = None
all_elems = []
tokens = set()
loaded_unknown_vec = None
line_num = 0
with io.open(pretrained_file_path, 'r', encoding=encoding) as f:
for line in f:
line_num += 1
elems = line.rstrip().split(elem_delim)
assert len(elems) > 1, 'At line %d of the pre-trained text embedding file: the ' \
'data format of the pre-trained token embedding file %s ' \
'is unexpected.' % (line_num, pretrained_file_path)
token, elems = elems[0], [float(i) for i in elems[1:]]
if token == self.unknown_token and loaded_unknown_vec is None:
loaded_unknown_vec = elems
tokens.add(self.unknown_token)
elif token in tokens:
warnings.warn('At line %d of the pre-trained token embedding file: the '
'embedding vector for token %s has been loaded and a duplicate '
'embedding for the same token is seen and skipped.' %
(line_num, token))
elif len(elems) == 1:
warnings.warn('At line %d of the pre-trained text embedding file: token %s '
'with 1-dimensional vector %s is likely a header and is '
'skipped.' % (line_num, token, elems))
else:
if vec_len is None:
vec_len = len(elems)
# Reserve a vector slot for the unknown token at the very beggining because
# the unknown index is 0.
all_elems.extend([0] * vec_len)
else:
assert len(elems) == vec_len, \
'At line %d of the pre-trained token embedding file: the dimension ' \
'of token %s is %d but the dimension of previous tokens is %d. ' \
'Dimensions of all the tokens must be the same.' \
% (line_num, token, len(elems), vec_len)
all_elems.extend(elems)
self._idx_to_token.append(token)
self._token_to_idx[token] = len(self._idx_to_token) - 1
tokens.add(token)
self._vec_len = vec_len
self._idx_to_vec = nd.array(all_elems).reshape((-1, self.vec_len))
if loaded_unknown_vec is None:
self._idx_to_vec[C.UNKNOWN_IDX] = init_unknown_vec(shape=self.vec_len)
else:
self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
def _index_tokens_from_vocabulary(self, vocabulary):
self._token_to_idx = vocabulary.token_to_idx.copy() \
if vocabulary.token_to_idx is not None else None
self._idx_to_token = vocabulary.idx_to_token[:] \
if vocabulary.idx_to_token is not None else None
self._unknown_token = vocabulary.unknown_token
self._reserved_tokens = vocabulary.reserved_tokens[:] \
if vocabulary.reserved_tokens is not None else None
def _set_idx_to_vec_by_embeddings(self, token_embeddings, vocab_len, vocab_idx_to_token):
"""Sets the mapping between token indices and token embedding vectors.
Parameters
----------
token_embeddings : instance or list `mxnet.contrib.text.embedding._TokenEmbedding`
One or multiple pre-trained token embeddings to load. If it is a list of multiple
embeddings, these embedding vectors will be concatenated for each token.
vocab_len : int
Length of vocabulary whose tokens are indexed in the token embedding.
vocab_idx_to_token: list of str
A list of indexed tokens in the vocabulary. These tokens are indexed in the token
embedding.
"""
new_vec_len = sum(embed.vec_len for embed in token_embeddings)
new_idx_to_vec = nd.zeros(shape=(vocab_len, new_vec_len))
col_start = 0
# Concatenate all the embedding vectors in token_embeddings.
for embed in token_embeddings:
col_end = col_start + embed.vec_len
# Cancatenate vectors of the unknown token.
new_idx_to_vec[0, col_start:col_end] = embed.idx_to_vec[0]
new_idx_to_vec[1:, col_start:col_end] = embed.get_vecs_by_tokens(vocab_idx_to_token[1:])
col_start = col_end
self._vec_len = new_vec_len
self._idx_to_vec = new_idx_to_vec
def _build_embedding_for_vocabulary(self, vocabulary):
if vocabulary is not None:
assert isinstance(vocabulary, vocab.Vocabulary), \
'The argument `vocabulary` must be an instance of ' \
'mxnet.contrib.text.vocab.Vocabulary.'
# Set _idx_to_vec so that indices of tokens from vocabulary are associated with the
# loaded token embedding vectors.
self._set_idx_to_vec_by_embeddings([self], len(vocabulary), vocabulary.idx_to_token)
# Index tokens from vocabulary.
self._index_tokens_from_vocabulary(vocabulary)
@property
def vec_len(self):
return self._vec_len
@property
def idx_to_vec(self):
return self._idx_to_vec
def get_vecs_by_tokens(self, tokens, lower_case_backup=False):
"""Look up embedding vectors of tokens.
Parameters
----------
tokens : str or list of strs
A token or a list of tokens.
lower_case_backup : bool, default False
If False, each token in the original case will be looked up; if True, each token in the
original case will be looked up first, if not found in the keys of the property
`token_to_idx`, the token in the lower case will be looked up.
Returns
-------
mxnet.ndarray.NDArray:
The embedding vector(s) of the token(s). According to numpy conventions, if `tokens` is
a string, returns a 1-D NDArray of shape `self.vec_len`; if `tokens` is a list of
strings, returns a 2-D NDArray of shape=(len(tokens), self.vec_len).
"""
to_reduce = False
if not isinstance(tokens, list):
tokens = [tokens]
to_reduce = True
if not lower_case_backup:
indices = [self.token_to_idx.get(token, C.UNKNOWN_IDX) for token in tokens]
else:
indices = [self.token_to_idx[token] if token in self.token_to_idx
else self.token_to_idx.get(token.lower(), C.UNKNOWN_IDX)
for token in tokens]
vecs = nd.Embedding(nd.array(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
self.idx_to_vec.shape[1])
return vecs[0] if to_reduce else vecs
def update_token_vectors(self, tokens, new_vectors):
"""Updates embedding vectors for tokens.
Parameters
----------
tokens : str or a list of strs
A token or a list of tokens whose embedding vector are to be updated.
new_vectors : mxnet.ndarray.NDArray
An NDArray to be assigned to the embedding vectors of `tokens`. Its length must be equal
to the number of `tokens` and its width must be equal to the dimension of embeddings of
the glossary. If `tokens` is a singleton, it must be 1-D or 2-D. If `tokens` is a list
of multiple strings, it must be 2-D.
"""
assert self.idx_to_vec is not None, 'The property `idx_to_vec` has not been properly set.'
if not isinstance(tokens, list) or len(tokens) == 1:
assert isinstance(new_vectors, nd.NDArray) and len(new_vectors.shape) in [1, 2], \
'`new_vectors` must be a 1-D or 2-D NDArray if `tokens` is a singleton.'
if not isinstance(tokens, list):
tokens = [tokens]
if len(new_vectors.shape) == 1:
new_vectors = new_vectors.expand_dims(0)
else:
assert isinstance(new_vectors, nd.NDArray) and len(new_vectors.shape) == 2, \
'`new_vectors` must be a 2-D NDArray if `tokens` is a list of multiple strings.'
assert new_vectors.shape == (len(tokens), self.vec_len), \
'The length of new_vectors must be equal to the number of tokens and the width of' \
'new_vectors must be equal to the dimension of embeddings of the glossary.'
indices = []
for token in tokens:
if token in self.token_to_idx:
indices.append(self.token_to_idx[token])
else:
raise ValueError('Token %s is unknown. To update the embedding vector for an '
'unknown token, please specify it explicitly as the '
'`unknown_token` %s in `tokens`. This is to avoid unintended '
'updates.' % (token, self.idx_to_token[C.UNKNOWN_IDX]))
self._idx_to_vec[nd.array(indices)] = new_vectors
@classmethod
def _check_pretrained_file_names(cls, pretrained_file_name):
"""Checks if a pre-trained token embedding file name is valid.
Parameters
----------
pretrained_file_name : str
The pre-trained token embedding file.
"""
embedding_name = cls.__name__.lower()
if pretrained_file_name not in cls.pretrained_file_name_sha1:
raise KeyError('Cannot find pretrained file %s for token embedding %s. Valid '
'pretrained files for embedding %s: %s' %
(pretrained_file_name, embedding_name, embedding_name,
', '.join(cls.pretrained_file_name_sha1.keys())))
@register
[docs]class GloVe(_TokenEmbedding):
"""The GloVe word embedding.
GloVe is an unsupervised learning algorithm for obtaining vector representations for words.
Training is performed on aggregated global word-word co-occurrence statistics from a corpus, and
the resulting representations showcase interesting linear substructures of the word vector
space. (Source from https://nlp.stanford.edu/projects/glove/)
References
----------
GloVe: Global Vectors for Word Representation.
Jeffrey Pennington, Richard Socher, and Christopher D. Manning.
https://nlp.stanford.edu/pubs/glove.pdf
Website:
https://nlp.stanford.edu/projects/glove/
To get the updated URLs to the externally hosted pre-trained token embedding
files, visit https://nlp.stanford.edu/projects/glove/
License for pre-trained embeddings:
https://fedoraproject.org/wiki/Licensing/PDDL
Parameters
----------
pretrained_file_name : str, default 'glove.840B.300d.txt'
The name of the pre-trained token embedding file.
embedding_root : str, default $MXNET_HOME/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
vocabulary : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
It contains the tokens to index. Each indexed token will be associated with the loaded
embedding vectors, such as loaded from a pre-trained token embedding file. If None, all the
tokens from the loaded embedding vectors, such as loaded from a pre-trained token embedding
file, will be indexed.
"""
# Map a pre-trained token embedding archive file and its SHA-1 hash.
pretrained_archive_name_sha1 = C.GLOVE_PRETRAINED_FILE_SHA1
# Map a pre-trained token embedding file and its SHA-1 hash.
pretrained_file_name_sha1 = C.GLOVE_PRETRAINED_ARCHIVE_SHA1
@classmethod
def _get_download_file_name(cls, pretrained_file_name):
# Map a pre-trained embedding file to its archive to download.
src_archive = {archive.split('.')[1]: archive for archive in
GloVe.pretrained_archive_name_sha1.keys()}
archive = src_archive[pretrained_file_name.split('.')[1]]
return archive
def __init__(self, pretrained_file_name='glove.840B.300d.txt',
embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
GloVe._check_pretrained_file_names(pretrained_file_name)
super(GloVe, self).__init__(**kwargs)
pretrained_file_path = GloVe._get_pretrained_file(embedding_root, pretrained_file_name)
self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
if vocabulary is not None:
self._build_embedding_for_vocabulary(vocabulary)
@register
[docs]class FastText(_TokenEmbedding):
"""The fastText word embedding.
FastText is an open-source, free, lightweight library that allows users to learn text
representations and text classifiers. It works on standard, generic hardware. Models can later
be reduced in size to even fit on mobile devices. (Source from https://fasttext.cc/)
References
----------
Enriching Word Vectors with Subword Information.
Piotr Bojanowski, Edouard Grave, Armand Joulin, and Tomas Mikolov.
https://arxiv.org/abs/1607.04606
Bag of Tricks for Efficient Text Classification.
Armand Joulin, Edouard Grave, Piotr Bojanowski, and Tomas Mikolov.
https://arxiv.org/abs/1607.01759
FastText.zip: Compressing text classification models.
Armand Joulin, Edouard Grave, Piotr Bojanowski, Matthijs Douze, Herve Jegou,
and Tomas Mikolov.
https://arxiv.org/abs/1612.03651
For 'wiki.multi' embeddings:
Word Translation Without Parallel Data
Alexis Conneau, Guillaume Lample, Marc'Aurelio Ranzato, Ludovic Denoyer,
and Herve Jegou.
https://arxiv.org/abs/1710.04087
Website:
https://fasttext.cc/
To get the updated URLs to the externally hosted pre-trained token embedding files, visit
https://github.com/facebookresearch/fastText/blob/master/docs/pretrained-vectors.md
License for pre-trained embeddings:
https://creativecommons.org/licenses/by-sa/3.0/
Parameters
----------
pretrained_file_name : str, default 'wiki.en.vec'
The name of the pre-trained token embedding file.
embedding_root : str, default $MXNET_HOME/embeddings
The root directory for storing embedding-related files.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
vocabulary : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
It contains the tokens to index. Each indexed token will be associated with the loaded
embedding vectors, such as loaded from a pre-trained token embedding file. If None, all the
tokens from the loaded embedding vectors, such as loaded from a pre-trained token embedding
file, will be indexed.
"""
# Map a pre-trained token embedding archive file and its SHA-1 hash.
pretrained_archive_name_sha1 = C.FAST_TEXT_ARCHIVE_SHA1
# Map a pre-trained token embedding file and its SHA-1 hash.
pretrained_file_name_sha1 = C.FAST_TEXT_FILE_SHA1
@classmethod
def _get_download_file_name(cls, pretrained_file_name):
# Map a pre-trained embedding file to its archive to download.
return '.'.join(pretrained_file_name.split('.')[:-1])+'.zip'
def __init__(self, pretrained_file_name='wiki.simple.vec',
embedding_root=os.path.join(base.data_dir(), 'embeddings'),
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
FastText._check_pretrained_file_names(pretrained_file_name)
super(FastText, self).__init__(**kwargs)
pretrained_file_path = FastText._get_pretrained_file(embedding_root, pretrained_file_name)
self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
if vocabulary is not None:
self._build_embedding_for_vocabulary(vocabulary)
[docs]class CustomEmbedding(_TokenEmbedding):
"""User-defined token embedding.
This is to load embedding vectors from a user-defined pre-trained text embedding file.
Denote by '[ed]' the argument `elem_delim`. Denote by [v_ij] the j-th element of the token
embedding vector for [token_i], the expected format of a custom pre-trained token embedding file
is:
'[token_1][ed][v_11][ed][v_12][ed]...[ed][v_1k]\\\\n[token_2][ed][v_21][ed][v_22][ed]...[ed]
[v_2k]\\\\n...'
where k is the length of the embedding vector `vec_len`.
Parameters
----------
pretrained_file_path : str
The path to the custom pre-trained token embedding file.
elem_delim : str, default ' '
The delimiter for splitting a token and every embedding vector element value on the same
line of the custom pre-trained token embedding file.
encoding : str, default 'utf8'
The encoding scheme for reading the custom pre-trained token embedding file.
init_unknown_vec : callback
The callback used to initialize the embedding vector for the unknown token.
vocabulary : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
It contains the tokens to index. Each indexed token will be associated with the loaded
embedding vectors, such as loaded from a pre-trained token embedding file. If None, all the
tokens from the loaded embedding vectors, such as loaded from a pre-trained token embedding
file, will be indexed.
"""
def __init__(self, pretrained_file_path, elem_delim=' ', encoding='utf8',
init_unknown_vec=nd.zeros, vocabulary=None, **kwargs):
super(CustomEmbedding, self).__init__(**kwargs)
self._load_embedding(pretrained_file_path, elem_delim, init_unknown_vec, encoding)
if vocabulary is not None:
self._build_embedding_for_vocabulary(vocabulary)
[docs]class CompositeEmbedding(_TokenEmbedding):
"""Composite token embeddings.
For each indexed token in a vocabulary, multiple embedding vectors, such as concatenated
multiple embedding vectors, will be associated with it. Such embedding vectors can be loaded
from externally hosted or custom pre-trained token embedding files, such as via token embedding
instances.
Parameters
----------
vocabulary : :class:`~mxnet.contrib.text.vocab.Vocabulary`
For each indexed token in a vocabulary, multiple embedding vectors, such as concatenated
multiple embedding vectors, will be associated with it.
token_embeddings : instance or list of `mxnet.contrib.text.embedding._TokenEmbedding`
One or multiple pre-trained token embeddings to load. If it is a list of multiple
embeddings, these embedding vectors will be concatenated for each token.
"""
def __init__(self, vocabulary, token_embeddings):
# Sanity checks.
assert isinstance(vocabulary, vocab.Vocabulary), \
'The argument `vocabulary` must be an instance of ' \
'mxnet.contrib.text.indexer.Vocabulary.'
if not isinstance(token_embeddings, list):
token_embeddings = [token_embeddings]
for embed in token_embeddings:
assert isinstance(embed, _TokenEmbedding), \
'The argument `token_embeddings` must be an instance or a list of instances ' \
'of `mxnet.contrib.text.embedding.TextEmbedding` whose embedding vectors will be' \
'loaded or concatenated-then-loaded to map to the indexed tokens.'
# Index tokens.
self._index_tokens_from_vocabulary(vocabulary)
# Set _idx_to_vec so that indices of tokens from keys of `counter` are associated with token
# embedding vectors from `token_embeddings`.
self._set_idx_to_vec_by_embeddings(token_embeddings, len(self), self.idx_to_token)