def __init__(self, *, iterator, noise_iterator, optimizer_generator,
             optimizer_discriminator, generator_lr_decay_interval,
             discriminator_lr_decay_interval, gamma, k_0, lambda_k,
             loss_norm, device=-1):
    """Register iterators/optimizers and store the training hyperparameters.

    Args:
        iterator: Data iterator, registered with the base updater under
            the ``'main'`` key.
        noise_iterator: Iterator registered under the ``'z'`` key
            (presumably yields latent noise batches — confirm with caller).
        optimizer_generator: Optimizer registered under the ``'gen'`` key.
        optimizer_discriminator: Optimizer registered under the ``'dis'``
            key.
        generator_lr_decay_interval: Stored as
            ``self.gen_lr_decay_interval``; its use lives outside this
            method.
        discriminator_lr_decay_interval: Stored as
            ``self.dis_lr_decay_interval``.
        gamma: Stored as ``self.gamma``.
        k_0: Initial value of ``self.k``.
        lambda_k: Stored as ``self.lambda_k``.
        loss_norm: Stored as ``self.loss_norm``.
        device: Forwarded unchanged to the base updater. When it is a
            non-negative integer, that GPU is made current and the
            ``target`` of every optimizer is moved to the GPU. ``None`` or a
            negative value skips the GPU transfer.

    NOTE(review): ``gamma``/``k_0``/``lambda_k`` look like BEGAN-style
    equilibrium hyperparameters; this method only stores them — confirm
    their semantics against the update step.
    """
    iterators = {'main': iterator, 'z': noise_iterator}
    optimizers = {'gen': optimizer_generator,
                  'dis': optimizer_discriminator}
    super().__init__(iterators, optimizers, device=device)

    self.gen_lr_decay_interval = generator_lr_decay_interval
    self.dis_lr_decay_interval = discriminator_lr_decay_interval
    self.k = k_0
    self.lambda_k = lambda_k
    self.gamma = gamma
    self.loss_norm = loss_norm

    # Treat ``None`` as "no GPU": a bare ``device >= 0`` raises TypeError
    # for None on Python 3, although None is a common "CPU" value for
    # Chainer updaters.
    if device is not None and device >= 0:
        cuda.get_device(device).use()
        for optimizer in optimizers.values():
            # to_gpu() without an argument is expected to use the device
            # made current just above.
            optimizer.target.to_gpu()
评论列表
文章目录