trainer.py source code

python

Project: chainer-cf-nade  Author: dsanno
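# Excerpt of a trainer method. Assumes `import numpy as np`,
# `from chainer import Variable`, `import chainer.functions as F`, and that
# loss_mask / ordinal_loss are defined elsewhere in the project.
def __forward(self, batch_x, batch_t, weight, train=True):
    xp = self.xp
    # volatile inputs skip building the backward graph when not training.
    x = Variable(xp.asarray(batch_x), volatile=not train)
    t = Variable(xp.asarray(batch_t), volatile=not train)
    y = self.net(x, train=train)

    b, c, n = y.data.shape
    # Broadcast the per-sample weight to y's shape and combine it with the loss mask.
    sample_weight = np.broadcast_to(weight.reshape(-1, 1, 1), (b, c, n))
    mask = Variable(xp.asarray(sample_weight * loss_mask(batch_t, self.net.rating_num)), volatile=not train)
    if self.ordinal_weight == 0:
        # Masked negative log-likelihood only, averaged over the batch.
        loss = F.sum(-F.log_softmax(y) * mask) / b
    elif self.ordinal_weight == 1:
        # Ordinal loss only.
        loss = ordinal_loss(y, mask)
    else:
        # Convex combination of the two losses.
        loss = (1 - self.ordinal_weight) * F.sum(-F.log_softmax(y) * mask) / b + self.ordinal_weight * ordinal_loss(y, mask)

    acc = self.__accuracy(y, t)
    return loss, acc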
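
The loss arithmetic in `__forward` can be checked by hand on tiny inputs. The following is a minimal pure-Python sketch, not the project's implementation. It assumes the softmax runs over the second axis (`c`) of a `(b, c, n)` array and that the mask is a per-sample weight times a 0/1 indicator (the real `loss_mask` and `ordinal_loss` are not shown in this excerpt). The helper names `log_softmax`, `masked_nll`, and `blend` are hypothetical.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a flat list of class scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def masked_nll(y, mask):
    """Compute sum(-log_softmax(y) * mask) / b for nested lists of shape (b, c, n).

    Assumption: the softmax is taken over axis c, separately for each
    sample and each position along n.
    """
    b = len(y)
    total = 0.0
    for y_b, m_b in zip(y, mask):
        c, n = len(y_b), len(y_b[0])
        for j in range(n):
            logp = log_softmax([y_b[k][j] for k in range(c)])
            total += sum(-logp[k] * m_b[k][j] for k in range(c))
    return total / b


def blend(nll, ordinal, ordinal_weight):
    """Convex combination used in the final else branch of __forward."""
    return (1 - ordinal_weight) * nll + ordinal_weight * ordinal


# One sample, two classes, one position; only the first class is unmasked.
y = [[[0.0], [0.0]]]
mask = [[[1.0], [0.0]]]
nll = masked_nll(y, mask)  # equals log(2) for uniform scores
```

With `ordinal_weight` set to 0 or 1, `blend` reduces to one of its two inputs, which matches the first two branches of the method; those branches simply avoid computing the term that would be multiplied by zero.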