def __init__(self, *, iterator, noise_iterator, optimizer_generator,
             optimizer_critic, device=-1):
    """Initialize the updater with its iterators and two optimizers.

    All arguments are keyword-only.

    Args:
        iterator: Iterator registered with the parent class under the
            key ``'main'``.
        noise_iterator: Iterator registered with the parent class under
            the key ``'z'``.
        optimizer_generator: Optimizer registered under the key
            ``'generator'``. If ``optimizer_generator.target.name`` is
            ``None`` it is set to ``'generator'``.
        optimizer_critic: Optimizer registered under the key
            ``'critic'``. If ``optimizer_critic.target.name`` is
            ``None`` it is set to ``'critic'``.
        device: Forwarded unchanged to the parent initializer. When it
            is ``>= 0``, ``cuda.get_device(device).use()`` is called and
            every optimizer target is moved with ``to_gpu()``.
            NOTE(review): presumably a CUDA device id where a negative
            value means "run on CPU" -- confirm against the parent class.

    Side effects:
        Sets ``self.xp`` to ``cuda.cupy`` when ``device >= 0`` and to
        ``np`` otherwise.
    """
    optimizers = {'generator': optimizer_generator,
                  'critic': optimizer_critic}

    # The dict keys double as the default names, so a target that has no
    # name yet simply inherits the key it is registered under.
    for default_name, optimizer in optimizers.items():
        if optimizer.target.name is None:
            optimizer.target.name = default_name

    iterators = {'main': iterator, 'z': noise_iterator}
    super().__init__(iterators, optimizers, device=device)

    if device >= 0:
        cuda.get_device(device).use()
        # Plain loop: to_gpu() is called for its side effect only, so
        # there is no reason to collect the return values in a list.
        for optimizer in optimizers.values():
            optimizer.target.to_gpu()
        self.xp = cuda.cupy
    else:
        self.xp = np
# [web page residue, not code] Comment list
# [web page residue, not code] Table of contents