def decode(self, input_vectors, output):
    """Compute the teacher-forced decoding loss for one target sequence.

    The decoder LSTM is seeded from the last encoder vector. Then, for each
    gold target token:
    - the current decoder output is projected to vocabulary scores;
    - that token's negative log-likelihood is recorded;
    - the gold token's embedding, concatenated with an attention vector over
      ``input_vectors``, is fed back as the next decoder input.

    Args:
        input_vectors: Sequence of encoder output expressions. Its last
            element seeds the decoder state, and the whole sequence is passed
            to ``self.attend`` at every step.
        output: Iterable of target tokens. Each one is mapped through
            ``self.tgt_vocab`` to an entry whose ``.i`` attribute is used as
            the vocabulary index, both for the loss and for the embedding
            lookup.

    Returns:
        A DyNet expression holding the sum of the per-token negative
        log-likelihoods.

    NOTE(review): an empty ``output`` reaches ``dynet.esum([])``; that case
    is not handled here.
    """
    tgt_toks = [self.tgt_vocab[tok] for tok in output]

    w = dynet.parameter(self.decoder_w)
    b = dynet.parameter(self.decoder_b)

    # Seed the decoder with the final encoder vector concatenated with a
    # hidden_dim-sized input vector. NOTE(review): presumably vecInput yields
    # zeros, and the concatenated size matches the later
    # [embedding, attention] inputs to the decoder LSTM; confirm.
    s = self.dec_lstm.initial_state()
    s = s.add_input(dynet.concatenate([
        input_vectors[-1],
        dynet.vecInput(self.args.hidden_dim),
    ]))

    losses = []
    for tok in tgt_toks:
        # Score every vocabulary entry from the current decoder output.
        out_vector = dynet.affine_transform([b, w, s.output()])
        # pickneglogsoftmax computes -log(softmax(x)[i]) as one fused op. It
        # is numerically stable, whereas a separate softmax -> pick -> log
        # chain yields -log(0) = inf if the target probability underflows.
        losses.append(dynet.pickneglogsoftmax(out_vector, tok.i))

        # Teacher forcing: feed the gold token's embedding together with an
        # attention vector computed from the current decoder state.
        embed_vector = self.tgt_lookup[tok.i]
        attn_vector = self.attend(input_vectors, s)
        s = s.add_input(dynet.concatenate([embed_vector, attn_vector]))

    return dynet.esum(losses)
评论列表
文章目录