updater.py source code

Project: chainer-wasserstein-gan    Author: hvy
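import numpy as np
from chainer import cuda, training


# Class name assumed; the listing shows only the __init__ method.
class WassersteinGANUpdater(training.StandardUpdater):

    def __init__(self, *, iterator, noise_iterator, optimizer_generator,
                 optimizer_critic, device=-1):

        # Name the models unless the caller has already named them.
        if optimizer_generator.target.name is None:
            optimizer_generator.target.name = 'generator'

        if optimizer_critic.target.name is None:
            optimizer_critic.target.name = 'critic'

        # 'main' yields real samples; 'z' yields noise inputs for the generator.
        iterators = {'main': iterator, 'z': noise_iterator}
        optimizers = {'generator': optimizer_generator,
                      'critic': optimizer_critic}

        super().__init__(iterators, optimizers, device=device)

        # device < 0 means CPU; otherwise move both models to that GPU.
        if device >= 0:
            cuda.get_device(device).use()
            for optimizer in optimizers.values():
                optimizer.target.to_gpu()

        # Array module for later computations: CuPy on GPU, NumPy on CPU.
        self.xp = cuda.cupy if device >= 0 else np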
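To see the constructor's bookkeeping without installing Chainer or a GPU, here is a minimal framework-free sketch. `StubModel`, `StubOptimizer` and `SketchUpdater` are hypothetical stand-ins, not part of chainer-wasserstein-gan or Chainer. They keep only the `name` and `target` attributes the constructor touches, and the sketch deliberately drops the GPU branch.

```python
import numpy as np


class StubModel:
    """Hypothetical stand-in for a chainer.Link; only `name` is used here."""

    def __init__(self, name=None):
        self.name = name


class StubOptimizer:
    """Hypothetical stand-in for a chainer.Optimizer; `target` is its model."""

    def __init__(self, target):
        self.target = target


class SketchUpdater:
    """Framework-free mirror of the constructor's bookkeeping (CPU only)."""

    def __init__(self, *, iterator, noise_iterator, optimizer_generator,
                 optimizer_critic, device=-1):
        # Fill in default names only when the caller left them unset.
        if optimizer_generator.target.name is None:
            optimizer_generator.target.name = 'generator'
        if optimizer_critic.target.name is None:
            optimizer_critic.target.name = 'critic'

        # Same role-keyed dictionaries the original passes to StandardUpdater.
        self.iterators = {'main': iterator, 'z': noise_iterator}
        self.optimizers = {'generator': optimizer_generator,
                           'critic': optimizer_critic}

        # This sketch has no GPU path, so a non-negative device is rejected.
        if device >= 0:
            raise NotImplementedError('GPU path omitted in this sketch')
        self.xp = np


updater = SketchUpdater(
    iterator=iter([[1.0, 2.0]]),
    noise_iterator=iter([[0.5, -0.5]]),
    optimizer_generator=StubOptimizer(StubModel()),
    optimizer_critic=StubOptimizer(StubModel(name='my_critic')),
)
print(updater.optimizers['generator'].target.name)  # generator
print(updater.optimizers['critic'].target.name)     # my_critic

# Device-agnostic code goes through xp instead of importing numpy/cupy directly.
z = updater.xp.zeros((2, 4), dtype=updater.xp.float32)
```

An unnamed generator receives the default name, while a critic the caller already named keeps its own name. Storing `xp` once lets the rest of the updater call the same array API on either NumPy or CuPy without checking the device again.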