def __init__(self):
    """Build the CycleGAN networks, loss functions and optimizers.

    Creates two generators (``netG_A``, ``netG_B``) and two discriminators
    (``netD_A``, ``netD_B``). Their layer sizes come from the ``g_*`` and
    ``d_*`` settings, which are defined outside this method.

    Also creates:
      * ``criterionGAN`` (a ``GANLoss``) and ``criterionCycle`` (an L1 loss);
      * one Adam optimizer over both generators' parameters combined;
      * one Adam optimizer per discriminator.

    All three optimizers share ``d_learning_rate`` and ``optim_betas``.
    """
    # Both generators get identical constructor arguments.
    generator_args = {
        'input_size': g_input_size,
        'hidden_size': g_hidden_size,
        'output_size': g_output_size,
    }
    # Keep the construction order G_A, G_B, D_A, D_B. If these constructors
    # draw random initial weights, reordering them would change the weights.
    self.netG_A = Generator(**generator_args)
    self.netG_B = Generator(**generator_args)

    # Both discriminators get identical constructor arguments.
    discriminator_args = {
        'input_size': d_input_size,
        'hidden_size': d_hidden_size,
        'output_size': d_output_size,
    }
    self.netD_A = Discriminator(**discriminator_args)
    self.netD_B = Discriminator(**discriminator_args)
    print('---------- Networks initialized -------------')

    # Adversarial loss plus L1 cycle-consistency loss.
    self.criterionGAN = GANLoss()
    self.criterionCycle = torch.nn.L1Loss()

    # A single optimizer steps both generators together, so their parameter
    # iterators are concatenated.
    both_generator_params = itertools.chain(
        self.netG_A.parameters(), self.netG_B.parameters()
    )
    # NOTE(review): the generator optimizer also uses d_learning_rate. If the
    # file defines a separate generator learning rate, confirm that reusing
    # the discriminator rate here is intended.
    adam_settings = {'lr': d_learning_rate, 'betas': optim_betas}
    self.optimizer_G = torch.optim.Adam(both_generator_params, **adam_settings)
    self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), **adam_settings)
    self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), **adam_settings)
评论列表
文章目录