model.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:chainer_nmt 作者: odashi 项目源码 文件源码
def _encode(self, x_list):
  """Run a bidirectional LSTM encoder over a source sequence.

  Args:
    x_list: per-time-step source inputs, ``len(x_list)`` steps long. Each
      element is wrapped by ``_mkivar`` and embedded by ``self.x_i``;
      ``len(x_list[0])`` is taken as the batch size (presumably each step
      holds one token ID per sentence -- confirm against callers).

  Returns:
    A tuple ``(fb_mat, fbe_mat, fc, bc, f_last, b_first)``:
      fb_mat: forward and backward hidden states joined on the feature
        axis, shape (batch, srclen, 2 * hidden_size).
      fbe_mat: ``self.fb_e`` applied to ``fb_mat`` flattened to
        (batch * srclen, 2 * hidden_size); its width is set by
        ``self.fb_e`` (an attention size, per the original shape notes).
      fc: forward LSTM cell state after the last source step.
      bc: backward LSTM cell state after reaching source position 0.
      f_last: forward hidden state at the last source step.
      b_first: backward hidden state at source position 0.
  """
  batch_size = len(x_list[0])
  source_length = len(x_list)

  # One zero tensor seeds cell and hidden state for both directions; the
  # loops below only rebind names and never mutate it, so sharing is safe.
  zero_state = _zeros((batch_size, self.hidden_size))

  # Embed every source step once; both directions reuse these embeddings.
  embedded = [self.x_i(_mkivar(x)) for x in x_list]

  # Forward direction: left to right over the source.
  fwd_cell = fwd_hidden = zero_state
  forward_states = []
  for emb in embedded:
    fwd_cell, fwd_hidden = F.lstm(
        fwd_cell, self.i_f(emb) + self.f_f(fwd_hidden))
    forward_states.append(fwd_hidden)

  # Backward direction: right to left over the source. The collected
  # states are flipped afterwards so index t matches source position t.
  bwd_cell = bwd_hidden = zero_state
  backward_states = []
  for emb in reversed(embedded):
    bwd_cell, bwd_hidden = F.lstm(
        bwd_cell, self.i_b(emb) + self.b_b(bwd_hidden))
    backward_states.append(bwd_hidden)
  backward_states = backward_states[::-1]

  def _to_batch_major(states):
    # srclen states of shape (batch, hidden) -> (batch, srclen, hidden).
    return F.concat([F.expand_dims(h, 1) for h in states], 1)

  f_mat = _to_batch_major(forward_states)
  b_mat = _to_batch_major(backward_states)
  # (batch, srclen, 2 * hidden): forward features first, then backward.
  fb_mat = F.concat([f_mat, b_mat], 2)

  # Flatten batch and time so self.fb_e sees one row per (sentence, step).
  flat_fb = F.reshape(
      fb_mat, [batch_size * source_length, 2 * self.hidden_size])
  fbe_mat = self.fb_e(flat_fb)

  return (fb_mat, fbe_mat, fwd_cell, bwd_cell,
          forward_states[-1], backward_states[0])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号