sequence2sequence.py source code

python

Project: lang-reps  Author: chaitanyamalaviya
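import dynet


def encode_batch_seq(self, src_seq, src_seq_rev):
        # Method of the model class. src_seq is a list of input expressions;
        # src_seq_rev is expected to hold the same inputs in reverse order.
        # The backward states are flipped back with [::-1] so that index i in
        # both lists refers to source position i.
        forward_states = self.enc_fwd_lstm.initial_state().add_inputs(src_seq)
        backward_states = self.enc_bwd_lstm.initial_state().add_inputs(src_seq_rev)[::-1]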

        src_encodings = []
        forward_cells = []
        backward_cells = []
        for forward_state, backward_state in zip(forward_states, backward_states):
            # For a single-layer LSTM, s() yields the memory cell followed by the hidden output.
            fwd_cell, fwd_enc = forward_state.s()
            bak_cell, bak_enc = backward_state.s()

            src_encodings.append(dynet.concatenate([fwd_enc, bak_enc]))
            forward_cells.append(fwd_cell)
            backward_cells.append(bak_cell)

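        # Final cells of both directions; computed here but not returned.
        decoder_init = dynet.concatenate([forward_cells[-1], backward_cells[0]])
        # Re-reversing backward_cells pairs the cells by encoder step rather than by source position.
        decoder_all = [dynet.concatenate([fwd, bwd]) for fwd, bwd in zip(forward_cells, list(reversed(backward_cells)))]
        return src_encodings, decoder_all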
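
The two pairings used in `encode_batch_seq` are easy to confuse, so here is a minimal sketch of the index alignment alone, without DyNet. `run_toy_rnn` and `align_bidirectional` are hypothetical helpers written for this illustration and are not part of lang-reps; the toy "state" is just the tuple of tokens read so far, standing in for an LSTM state.

```python
def run_toy_rnn(tokens):
    """Hypothetical stand-in for an LSTM: the state after each step
    is the tuple of all tokens read up to and including that step."""
    states = []
    seen = ()
    for tok in tokens:
        seen = seen + (tok,)
        states.append(seen)
    return states


def align_bidirectional(src_seq):
    """Mimic the list handling of encode_batch_seq on toy states."""
    forward_states = run_toy_rnn(src_seq)
    # Read the reversed sequence, then flip the states back into source order.
    backward_states = run_toy_rnn(src_seq[::-1])[::-1]

    # Pairing used for src_encodings: index i is source position i.
    per_position = list(zip(forward_states, backward_states))
    # Pairing used for decoder_all: index t is encoder step t in both directions.
    per_step = list(zip(forward_states, list(reversed(backward_states))))
    return per_position, per_step


per_position, per_step = align_bidirectional(["a", "b", "c"])
for i, (fwd, bwd) in enumerate(per_position):
    print("position", i, fwd, bwd)
for t, (fwd, bwd) in enumerate(per_step):
    print("step", t, fwd, bwd)
```

In the per-position pairing, the forward state at position 0 has read only the first token while the backward state there has read the whole sequence. In the per-step pairing, both states have read the same number of tokens.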