def loss(self, batch_data, test=False):
    """Compute the negative log-likelihood loss for one batch.

    The hidden state returned by the network is passed through
    ``u.repackage_hidden`` and cached in ``self.hidden_state['hidden']``;
    the cached value (if any) is fed back to the network on the next call.

    Args:
        batch_data: A ``(source, targets)`` pair. When ``self.conds`` is not
            ``None`` both items are themselves sequences: the first element
            of ``source`` is the network input and the remaining elements
            are passed on as ``conds``; only the first element of
            ``targets`` is used.
        test: If true, the backward pass is skipped.

    Returns:
        A pair ``((loss_value,), n_items)`` where ``loss_value`` is the
        batch loss as a plain Python number and ``n_items`` is
        ``source.nelement()``.
    """
    # unpack data
    source, targets = batch_data
    conds = None
    if self.conds is not None:
        (source, *conds), (targets, *_) = source, targets

    # get hidden from previous batch (if stored)
    hidden = self.hidden_state.get('hidden')

    # run RNN
    output, hidden, _ = self(source, hidden=hidden, conds=conds)

    # store hidden for next batch
    # NOTE(review): repackage_hidden presumably detaches the state from the
    # current graph so backprop does not reach into earlier batches - confirm.
    self.hidden_state['hidden'] = u.repackage_hidden(hidden)

    # compute loss and backward
    # Averaging over elements is nll_loss's default, so the deprecated
    # `size_average=True` keyword is omitted rather than replaced; this keeps
    # the call valid on both old and new PyTorch releases.
    # NOTE(review): assumes `output` holds log-probabilities with one row per
    # flattened target - confirm against the model's forward().
    loss = F.nll_loss(output, targets.view(-1))
    if not test:
        loss.backward()

    # `loss.data[0]` raises IndexError on 0-dim tensors (PyTorch >= 0.5);
    # prefer `item()` and fall back only where it does not exist.
    if hasattr(loss, 'item'):
        loss_value = loss.item()
    else:
        loss_value = loss.data[0]
    return (loss_value,), source.nelement()
# Web page navigation residue (not part of the code): "Comment list", "Article contents".