def validater(self, batch_loader):
    """Return a closure that evaluates the model on one validation batch.

    Args:
        self: the model. It is invoked directly as ``self(...)`` and must
            expose ``self.params.word_vocab_size``.
        batch_loader: object whose ``next_batch(batch_size, 'valid')``
            returns exactly five numpy arrays, in this order: encoder word
            input, encoder character input, decoder word input, decoder
            character input, target.

    Returns:
        A function ``validate(batch_size, use_cuda)`` returning
        ``(cross_entropy, kld)``.
    """

    def _as_long_tensor(array, use_cuda):
        # numpy array -> Variable -> int64 tensor, optionally moved to GPU.
        tensor = Variable(t.from_numpy(array)).long()
        return tensor.cuda() if use_cuda else tensor

    def validate(batch_size, use_cuda):
        """Score a single batch drawn from the 'valid' split.

        Every array from the loader is converted to an int64 tensor (moved
        to the GPU when ``use_cuda`` is true). The model is then run with
        ``z=None``, and logits/targets are flattened for token-level
        cross entropy.

        Returns:
            (cross_entropy, kld): ``F.cross_entropy`` over all flattened
            positions with its default reduction, and the ``kld`` value
            returned by the model, unchanged.
        """
        arrays = batch_loader.next_batch(batch_size, 'valid')
        (enc_words, enc_chars,
         dec_words, dec_chars,
         target) = [_as_long_tensor(array, use_cuda) for array in arrays]

        # NOTE(review): the leading 0. is presumably the dropout probability
        # (no dropout at validation) -- confirm against the model's forward().
        # Autograd stays enabled here; if callers never backprop through the
        # result, torch.no_grad() (torch >= 0.4) would save memory.
        logits, _, kld = self(0., enc_words, enc_chars, dec_words, dec_chars,
                              z=None)

        vocab_size = self.params.word_vocab_size
        cross_entropy = F.cross_entropy(logits.view(-1, vocab_size),
                                        target.view(-1))
        return cross_entropy, kld

    return validate
评论列表
文章目录