def get_batch_loss(self, input_batch, output_batch):
    """Build the computation graph for one minibatch and return its loss."""
    dynet.renew_cg()
    # wids / wids_reversed hold the word ids for each time step across the
    # whole minibatch; dimension: maxSentLength * minibatch_size.
    wids = []
    wids_reversed = []
    # Sentences shorter than maxSentLength are padded with END_TOK. A
    # per-timestep mask (1 if a real input is present, 0 if padding) could
    # also be tracked here, but this version omits it.
    maxSentLength = max(len(sent) for sent in input_batch)
    for j in range(maxSentLength):
        # Column of word ids at time step j, reading each sentence
        # left-to-right (wids) and right-to-left (wids_reversed).
        wids.append([(self.src_vocab[sent[j]].i if len(sent) > j
                      else self.src_vocab.END_TOK.i)
                     for sent in input_batch])
        wids_reversed.append([(self.src_vocab[sent[len(sent) - j - 1]].i if len(sent) > j
                               else self.src_vocab.END_TOK.i)
                              for sent in input_batch])
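    # Worked example (illustrative, not from the original code): for
    # input_batch = [["a", "b", "c"], ["d"]], maxSentLength is 3 and the
    # time-major matrices come out as
    #   wids          = [[a, d], [b, END], [c, END]]
    #   wids_reversed = [[c, d], [b, END], [a, END]]
    # where each letter stands for that word's vocabulary id.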
    # Embed the forward and reversed id sequences, then run the
    # bidirectional encoder over them.
    embedded_batch = self.embed_batch_seq(wids)
    embedded_batch_reverse = self.embed_batch_seq(wids_reversed)
    encoded_batch = self.encode_batch_seq(embedded_batch, embedded_batch_reverse)
    # Pass the encoder output (its last hidden state) to the decoder and
    # return the batch loss.
    return self.decode_batch(encoded_batch, output_batch)
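
For context, here is a minimal sketch of how get_batch_loss might be driven from a training loop. It assumes the model exposes its DyNet ParameterCollection as model.pc and that train_pairs is a list of (source, target) sentence pairs; both names are illustrative, not part of the original code.

import random
import dynet

def train(model, train_pairs, epochs=10, batch_size=32):
    # Assumptions for this sketch: `model` exposes its ParameterCollection
    # as `model.pc` and implements get_batch_loss as shown above;
    # `train_pairs` is a list of (source_sentence, target_sentence) pairs.
    trainer = dynet.SimpleSGDTrainer(model.pc)
    for epoch in range(epochs):
        random.shuffle(train_pairs)
        epoch_loss = 0.0
        for i in range(0, len(train_pairs), batch_size):
            batch = train_pairs[i:i + batch_size]
            input_batch = [src for src, tgt in batch]
            output_batch = [tgt for src, tgt in batch]
            loss = model.get_batch_loss(input_batch, output_batch)
            epoch_loss += loss.value()  # forward pass
            loss.backward()             # backprop through the whole batch
            trainer.update()
        print("epoch %d: total loss %.4f" % (epoch, epoch_loss))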