lm.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:c2w2c 作者: milankinen 项目源码 文件源码
def __init__(self, batch_size, d_W, d_L):
    """
      Build a stacked stateful-LSTM language model over context word embeddings.

      batch_size = batch size used in training/validation (mandatory because of stateful LSTMs)
      d_W        = word features (of output word embeddings from C2W sub-model)
      d_L        = language model hidden state size
    """
    # Combine the context embedding with an explicit per-sample mask:
    # the embedding is zeroed where mask == 0, and the mask is also
    # propagated as the Keras layer mask to the downstream LSTMs.
    def masked_ctx(emb, mask):
      class L(Lambda):
        def __init__(self):
          # x[0]: embedding reshaped to (batch, 1, d_W); x[1]: mask of shape (batch,).
          # Output = embedding * mask (broadcast over d_W); output shape = input_shapes[0].
          super(L, self).__init__(lambda x: x[0] * K.expand_dims(x[1], -1), lambda input_shapes: input_shapes[0])

        def compute_mask(self, x, input_mask=None):
          # Override mask computation: derive the mask from the explicit
          # mask input x[1] (expanded to (batch, 1), matching the single
          # timestep) instead of any upstream input_mask.
          return K.expand_dims(x[1], -1)
      return L()([Reshape((1, d_W))(emb), mask])

    # Placeholder for externally saved/restored LSTM states — not populated here;
    # presumably managed by sibling methods outside this view (TODO confirm).
    self._saved_states = None
    # Direct references to the stateful LSTM layers (e.g. for state resets).
    self._lstms = []

    # Fixed batch_shape is required because stateful LSTMs need a static batch size.
    ctx_emb   = Input(batch_shape=(batch_size, d_W), name='ctx_emb')
    ctx_mask  = Input(batch_shape=(batch_size,), name='ctx_mask')

    # Stack NUM_LSTMs layers; all but the last return full sequences so the
    # next LSTM receives 3D input, and the final layer emits only the last state.
    C = masked_ctx(ctx_emb, ctx_mask)
    for i in range(NUM_LSTMs):
      lstm = LSTM(d_L,
                  return_sequences=(i < NUM_LSTMs - 1),
                  stateful=True,
                  consume_less='gpu')
      self._lstms.append(lstm)
      C = lstm(C)

    # Finalize as a Keras (1.x) functional Model: [embedding, mask] -> LM hidden state.
    super(LanguageModel, self).__init__(input=[ctx_emb, ctx_mask], output=C, name='LanguageModel')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号