multi_updater.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:Comicolorization 作者: DwangoMediaVillage 项目源码 文件源码
def __init__(
            self,
            args,
            loss_maker,
            main_optimizer,
            main_lossfun,
            reinput_optimizer=None,
            reinput_lossfun=None,
            discriminator_optimizer=None,
            discriminator_lossfun=None,
            *_args, **kwargs
    ):
        # type: (any, comicolorization.loss.LossMaker, any, typing.Callable[[typing.Dict], any], typing.List[chainer.Optimizer], typing.Callable[[int, typing.Dict], any], any, typing.Callable[[typing.Dict], any], *any, **any) -> None
        """Register the optimizers with the parent updater and store the settings.

        The optimizers are gathered into one dict that is handed to the parent
        constructor through its ``optimizer`` keyword:

        * ``'main'`` -> ``main_optimizer`` (always present)
        * ``'reinput0'``, ``'reinput1'``, ... -> each item of
          ``reinput_optimizer``, in iteration order (only when it is not None)
        * ``'discriminator'`` -> ``discriminator_optimizer`` (only when it is
          not None)

        Args:
            args: Settings object. This method reads
                ``args.separate_backward_reinput`` and, when the fallback
                below applies, ``args.loss_blend_ratio_reinput``.
            loss_maker: Stored as ``self.loss_maker``.
            main_optimizer: Optimizer registered under ``'main'``.
            main_lossfun: Stored as ``self.main_lossfun``.
            reinput_optimizer: Optional iterable of optimizers, registered as
                ``'reinput<i>'``.
            reinput_lossfun: Stored as ``self.reinput_lossfun``.
            discriminator_optimizer: Optional optimizer registered under
                ``'discriminator'``.
            discriminator_lossfun: Stored as ``self.discriminator_lossfun``.
            *_args: Extra positional arguments forwarded to the parent
                constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        named_optimizers = {'main': main_optimizer}
        if reinput_optimizer is not None:
            named_optimizers.update(
                ('reinput{}'.format(index), opt)
                for index, opt in enumerate(reinput_optimizer)
            )
        if discriminator_optimizer is not None:
            named_optimizers['discriminator'] = discriminator_optimizer

        super().__init__(*_args, optimizer=named_optimizers, **kwargs)

        # Original note: chainer.reporter cannot work when several optimizers
        # focus on the same model. Presumably that is why this fallback list
        # (main_optimizer repeated once per re-input blend ratio) is built only
        # after the parent constructor has received the optimizers, so the
        # repeated entries are never registered -- TODO confirm.
        use_main_for_reinput = args.separate_backward_reinput and reinput_optimizer is None
        if use_main_for_reinput:
            reinput_optimizer = [main_optimizer] * len(args.loss_blend_ratio_reinput)

        self.args, self.loss_maker = args, loss_maker
        self.main_optimizer, self.main_lossfun = main_optimizer, main_lossfun
        self.reinput_optimizer, self.reinput_lossfun = reinput_optimizer, reinput_lossfun
        self.discriminator_optimizer = discriminator_optimizer
        self.discriminator_lossfun = discriminator_lossfun
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号