at.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:masalachai 作者: DaikiShimada 项目源码 文件源码
def supervised_update(self):
        # array backend
        xp = cuda.cupy if self.gpu >= 0 else numpy

        self.accuracy = None

        # read data
        data = self.train_data_queues[0].get()
        vx = tuple([chainer.Variable(xp.asarray(data[k]))
                    for k in data.keys() if 'data' in k])
        vt = tuple([chainer.Variable(xp.asarray(data[k]))
                    for k in data.keys() if 'target' in k])

        self.optimizer.update(self.adversarial_loss, vx, vt)

        # get result
        res = {'loss': float(self.loss.data),
               'adversarial_loss': float(self.adv_loss.data)}
        if self.accuracy is not None:
            res['accuracy'] = self.accuracy
        return res
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号