def encode_batch_seq(self, src_seq, src_seq_rev):
    """Encode a source sequence with the forward and backward encoder LSTMs.

    Args:
        src_seq: Inputs fed, in order, to ``self.enc_fwd_lstm``.
        src_seq_rev: Inputs fed, in order, to ``self.enc_bwd_lstm``
            (presumably ``src_seq`` reversed -- confirm against the caller).

    Returns:
        A tuple ``(src_encodings, decoder_all)`` of two lists with one entry
        per zipped (forward, backward) state pair:

        * ``src_encodings[i]`` concatenates the second component of ``s()``
          from forward state ``i`` and from the backward state at index ``i``
          of the flipped backward state list.
        * ``decoder_all[i]`` concatenates the first component of ``s()`` from
          forward state ``i`` with the first component taken from the
          backward list flipped back again (see the comment below).

        Both lists are empty when no states are produced.
    """
    forward_states = self.enc_fwd_lstm.initial_state().add_inputs(src_seq)
    # Flip the backward states so that index i lines up with forward index i.
    backward_states = self.enc_bwd_lstm.initial_state().add_inputs(src_seq_rev)[::-1]

    src_encodings = []
    forward_cells = []
    backward_cells = []
    for forward_state, backward_state in zip(forward_states, backward_states):
        # s() is unpacked as exactly two expressions; this assumes
        # single-layer LSTM builders (cell first, hidden second) -- TODO confirm.
        fwd_cell, fwd_enc = forward_state.s()
        bak_cell, bak_enc = backward_state.s()
        src_encodings.append(dynet.concatenate([fwd_enc, bak_enc]))
        forward_cells.append(fwd_cell)
        backward_cells.append(bak_cell)

    # NOTE(review): backward_cells is flipped a second time here, so each
    # forward cell is paired with the backward cell from the same LSTM step
    # rather than the same source position. Kept as in the original code --
    # confirm this is intended.
    decoder_all = [
        dynet.concatenate([fwd, bwd])
        for fwd, bwd in zip(forward_cells, reversed(backward_cells))
    ]
    return src_encodings, decoder_all
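# 评论列表 (comment list) -- web-page navigation residue, not part of the code
# 文章目录 (article table of contents) -- web-page navigation residue, not part of the code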