def __init__(
        self,
        args,
        loss_maker,
        main_optimizer,
        main_lossfun,
        reinput_optimizer=None,
        reinput_lossfun=None,
        discriminator_optimizer=None,
        discriminator_lossfun=None,
        *_args, **kwargs
):
    # type: (any, comicolorization.loss.LossMaker, any, typing.Callable[[typing.Dict], any], typing.List[chainer.Optimizer], typing.Callable[[int, typing.Dict], any], any, typing.Callable[[typing.Dict], any], *any, **any) -> None
    """Initialize the updater with one named optimizer per trained model.

    Builds a name->optimizer mapping ('main', 'reinput0', 'reinput1', ...,
    'discriminator'), hands it to the parent updater, then stores every
    argument on the instance for later use by the update/loss steps.

    :param args: parsed command-line options; `separate_backward_reinput`
        and `loss_blend_ratio_reinput` are read here.
    :param loss_maker: project loss factory (comicolorization.loss.LossMaker).
    :param main_optimizer: optimizer for the main model.
    :param main_lossfun: callable computing the main loss from a dict.
    :param reinput_optimizer: optional list of optimizers, one per re-input.
    :param reinput_lossfun: callable computing the re-input loss.
    :param discriminator_optimizer: optional optimizer for the discriminator.
    :param discriminator_lossfun: callable computing the discriminator loss.
    """
    named_optimizers = {'main': main_optimizer}
    if reinput_optimizer is not None:
        named_optimizers.update(
            ('reinput{}'.format(index), optimizer)
            for index, optimizer in enumerate(reinput_optimizer)
        )
    if discriminator_optimizer is not None:
        named_optimizers['discriminator'] = discriminator_optimizer
    super().__init__(optimizer=named_optimizers, *_args, **kwargs)

    # chainer.reporter cannot handle several optimizers that focus on the
    # same model, so this aliased fallback list is built only AFTER the
    # parent was initialized with the real optimizer mapping. Every entry
    # is the same main_optimizer object (shared references, not copies).
    if args.separate_backward_reinput and reinput_optimizer is None:
        reinput_optimizer = [main_optimizer] * len(args.loss_blend_ratio_reinput)

    self.args = args
    self.loss_maker = loss_maker
    self.main_optimizer = main_optimizer
    self.main_lossfun = main_lossfun
    self.reinput_optimizer = reinput_optimizer
    self.reinput_lossfun = reinput_lossfun
    self.discriminator_optimizer = discriminator_optimizer
    self.discriminator_lossfun = discriminator_lossfun
multi_updater.py — source file
python
Views: 18
Bookmarks: 0
Likes: 0
Comments: 0
Comment list
Table of contents