def __call__(self, x, context):
    """Compute the loss for ``x`` against the embedded ``context`` and report it.

    ``context`` is passed through ``self.embed``; the result must expose a
    ``shape`` with at least three entries, since indices 0, 1 and 2 are read.
    ``x`` is given a trailing axis, broadcast to the first two dimensions of
    the embedding, and both are flattened over those two dimensions before
    being handed to ``self.loss_func``.

    NOTE(review): presumably ``x`` is a 1-D array of ids whose length equals
    the first dimension of the embedding -- confirm against the caller.

    Returns:
        The value returned by ``self.loss_func``. It is also reported under
        the key ``'loss'``, with ``self`` passed as the second argument to
        ``reporter.report``.
    """
    embedded = self.embed(context)
    dims = embedded.shape
    n_rows, n_cols = dims[0], dims[1]
    n_flat = n_rows * n_cols

    # Repeat each entry of x along a new second axis so that there is one
    # copy per position along the embedding's second axis.
    targets = F.broadcast_to(x[:, None], (n_rows, n_cols))

    # Merge the first two axes so each embedded vector lines up with exactly
    # one entry of the flattened targets.
    flat_embedded = F.reshape(embedded, (n_flat, dims[2]))
    flat_targets = F.reshape(targets, (n_flat,))

    loss = self.loss_func(flat_embedded, flat_targets)
    reporter.report({'loss': loss}, self)
    return loss
# [web page navigation residue] Comment list
# [web page navigation residue] Table of contents