def __call__(self, x, context):
    """Compute and report the loss of predicting ``x`` from its context items.

    Each entry ``x[i]`` is repeated once for every item in row ``i`` of
    ``context``. Both arrays are then flattened (row-major, so positions stay
    aligned). The flattened context is encoded with ``self.rnn.charRNN`` and
    scored against the repeated targets with ``self.loss_func``.

    Args:
        x: Targets indexed along axis 0; ``x[:, None]`` is broadcast to
            ``(context.shape[0], context.shape[1])``. Presumably a 1-D batch
            of ids -- TODO confirm against the caller.
        context: Array whose first two axes are ``(batch, n_context)``. It is
            flattened to ``batch * n_context`` items, so it is assumed to be
            exactly 2-D (any extra axes would need size 1).

    Returns:
        Whatever ``self.loss_func(e, x)`` returns. The same value is also
        reported under the key ``'loss'`` with ``self`` as the observer.
    """
    batch_size = context.shape[0]
    n_context = context.shape[1]
    n_pairs = batch_size * n_context

    # Repeat each target across its row of context items, then flatten.
    x = F.broadcast_to(x[:, None], (batch_size, n_context))
    x = F.reshape(x, (n_pairs,))

    # Flatten the context with a genuine 1-tuple shape (note the trailing
    # comma), matching the reshape of ``x`` above.
    context = context.reshape((n_pairs,))

    # NOTE(review): presumably maps each flattened context item to an
    # embedding/feature vector; verify against the rnn module.
    e = self.rnn.charRNN(context)
    loss = self.loss_func(e, x)
    reporter.report({'loss': loss}, self)
    return loss
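# NOTE(review): stray web-page navigation label captured with this snippet, not code: "评论列表" ("Comment list")
# NOTE(review): stray web-page navigation label captured with this snippet, not code: "文章目录" ("Table of contents")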