def decode(self, encoding, input, output):
    """
    Compute the decoding loss for a single training example.

    The decoder LSTM starts from its initial state with ``encoding`` fed as
    the first input. At every step the current LSTM output is projected with
    ``decoder_w`` / ``decoder_b`` and scored against the expected target
    token; the embedding of that expected token is then fed to the LSTM as
    the next input.

    :param encoding: last hidden state from encoder
    :param input: source sentence; not used by the decoder, kept in the
        signature so existing callers do not have to change
    :param output: target sentence, an iterable of tokens that can be looked
        up in ``self.tgt_vocab``
    :return: sum of the per-token negative log-likelihoods, as produced by
        ``dynet.esum``
    """
    w = dynet.parameter(self.decoder_w)
    b = dynet.parameter(self.decoder_b)

    # Resolve every target token to its vocabulary index up front, so an
    # unknown token fails before any decoder step is added to the graph.
    tgt_ids = [self.tgt_vocab[tok].i for tok in output]

    state = self.dec_lstm.initial_state().add_input(encoding)
    losses = []
    for tok_id in tgt_ids:
        scores = dynet.affine_transform([b, w, state.output()])
        # Fused negative log-softmax: same value as -log(pick(softmax(x), i))
        # but avoids log(0) when the softmax underflows.
        losses.append(dynet.pickneglogsoftmax(scores, tok_id))
        # Feed the expected token (not the model's prediction) as next input.
        state = state.add_input(self.tgt_lookup[tok_id])

    # NOTE(review): an empty `output` hands an empty list to dynet.esum --
    # confirm callers never pass an empty target sentence.
    return dynet.esum(losses)
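# (Web-page residue, not part of the code: "Comment list")
# (Web-page residue, not part of the code: "Article table of contents")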