def __init__(self, batch_size, d_W, d_L):
    """Build the language-model graph: a masked context embedding fed
    through a stack of stateful LSTMs.

    batch_size = batch size used in training/validation (mandatory because
                 stateful LSTMs require a fixed batch dimension)
    d_W = word features (of output word embeddings from C2W sub-model)
    d_L = language model hidden state size
    """
    def masked_ctx(emb, mask):
        # Wrap the embedding so downstream layers see a (batch, 1, d_W)
        # sequence whose single timestep is zeroed out when mask == 0.
        class L(Lambda):
            def __init__(self):
                # x[0]: (batch, 1, d_W) embedding, x[1]: (batch,) mask.
                # Multiplying by the expanded mask zeroes masked rows;
                # output shape equals the embedding input's shape.
                super(L, self).__init__(lambda x: x[0] * K.expand_dims(x[1], -1), lambda input_shapes: input_shapes[0])
            def compute_mask(self, x, input_mask=None):
                # Propagate the scalar-per-sample mask to later layers
                # (expanded to match Keras' expected mask rank).
                return K.expand_dims(x[1], -1)
        return L()([Reshape((1, d_W))(emb), mask])
    # NOTE(review): _saved_states is only initialized here; presumably the
    # save/restore of LSTM states happens in another method — confirm.
    self._saved_states = None
    self._lstms = []
    # Inputs carry an explicit batch dimension because stateful LSTMs
    # need a fixed batch size at graph-construction time.
    ctx_emb = Input(batch_shape=(batch_size, d_W), name='ctx_emb')
    ctx_mask = Input(batch_shape=(batch_size,), name='ctx_mask')
    C = masked_ctx(ctx_emb, ctx_mask)
    # NUM_LSTMs is defined elsewhere in this module (not visible here).
    for i in range(NUM_LSTMs):
        lstm = LSTM(d_L,
                    # All but the last layer return full sequences so the
                    # next LSTM in the stack receives 3-D input.
                    return_sequences=(i < NUM_LSTMs - 1),
                    stateful=True,
                    consume_less='gpu')
        self._lstms.append(lstm)
        C = lstm(C)
    # Register the whole graph as a Keras model (Keras 1.x 'input'/'output'
    # keyword names).
    super(LanguageModel, self).__init__(input=[ctx_emb, ctx_mask], output=C, name='LanguageModel')