def __init__(self, lstm_num_layers, lstm_layer_size, trainable_embeddings, **kw):
    """Initializes the Keras LSTM question processing component.

    Builds a Keras ``Sequential`` model made of an ``Embedding`` layer whose
    initial weights come from ``get_embeddings(self._nlp.vocab)``, followed by
    ``lstm_num_layers`` stacked LSTM layers (each optionally wrapped in
    ``Bidirectional``). Every LSTM layer except the last returns the full
    sequence; the last one returns only its final output.

    Args:
        lstm_num_layers: Number of stacked LSTM layers; must be >= 1.
        lstm_layer_size: Output size (``output_dim``) of each LSTM layer.
        trainable_embeddings: Whether the embedding weights are updated during
            training. If truthy, the embedding layer is created under
            ``tf.device("/cpu:0")``.

    Keyword Args:
        max_sentence_length: Maximum number of words to consider in each
            question, necessary at train time. Defaults to 15.
        bidirectional: Whether to wrap each LSTM layer in ``Bidirectional``.
            Defaults to False.

    Raises:
        ValueError: If ``lstm_num_layers`` is smaller than 1.
    """
    # Without this check a value < 1 silently produced a 2-layer network.
    # Validate before the (slow) vocabulary/embedding loading below.
    if lstm_num_layers < 1:
        raise ValueError(
            'lstm_num_layers must be >= 1, got {}'.format(lstm_num_layers))

    print('Loading GloVe data... ', end='', flush=True)
    self._nlp = English()
    print('Done.')
    embeddings = get_embeddings(self._nlp.vocab)
    # NOTE(review): assumes `embeddings` is a 2-D (vocab_size, embedding_dims)
    # matrix, since it is passed as the Embedding layer's weights — confirm.
    vocab_size = embeddings.shape[0]
    embedding_dims = embeddings.shape[1]

    # TODO(Bernhard): Investigate how the LSTM parameters influence the
    # overall performance.
    self._max_len = kw.get('max_sentence_length', 15)
    self._bidirectional = kw.get('bidirectional', False)
    self._model = Sequential()

    def add_embedding_layer():
        # Single construction site for the embedding layer, shared by both
        # the CPU-pinned and the default-device branches below.
        self._model.add(Embedding(vocab_size, embedding_dims,
                                  input_length=self._max_len,
                                  trainable=bool(trainable_embeddings),
                                  weights=[embeddings]))

    if trainable_embeddings:
        # If embeddings are trainable we enforce CPU usage in order to not
        # run out of memory; this is device dependent.
        # TODO(Bernhard): preprocess questions and vocab and try if we can
        # get rid of enough words to make this run on GPU anyway.
        with tf.device("/cpu:0"):
            add_embedding_layer()
    else:
        # A non-trainable embedding layer can run on GPU without exhausting
        # all the memory.
        add_embedding_layer()

    for layer_idx in range(lstm_num_layers):
        is_last = layer_idx == lstm_num_layers - 1
        lstm_kwargs = {
            'output_dim': lstm_layer_size,
            # Intermediate layers must emit full sequences so the next LSTM
            # receives a 3-D input; only the last layer collapses to a vector.
            'return_sequences': not is_last,
        }
        if layer_idx == 0:
            lstm_kwargs['input_shape'] = (self._max_len, embedding_dims)
        lstm = LSTM(**lstm_kwargs)
        if self._bidirectional:
            lstm = Bidirectional(lstm)
        self._model.add(lstm)
评论列表
文章目录