mixGau_cycleGAN.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:probability_GAN 作者: MaureenZOU 项目源码 文件源码
def __init__(self):
        """Set up the CycleGAN components: two generators, two
        discriminators, the loss criteria, and one optimizer per
        network group.
        """
        # Generators for each mapping direction (A->B and B->A),
        # built with identical layer sizes.
        self.netG_A = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
        self.netG_B = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
        # One discriminator per domain, also identically sized.
        self.netD_A = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)
        self.netD_B = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)

        print('---------- Networks initialized -------------')

        # Adversarial loss plus an L1 cycle-consistency loss.
        self.criterionGAN = GANLoss()
        self.criterionCycle = torch.nn.L1Loss()

        # Both generators share a single Adam optimizer; each discriminator
        # gets its own, with L2 weight decay applied.
        # NOTE(review): the generator optimizer uses d_learning_rate — confirm
        # this is intentional and not a copy-paste of the discriminator lr.
        gen_params = itertools.chain(self.netG_A.parameters(), self.netG_B.parameters())
        self.optimizer_G = torch.optim.Adam(gen_params, lr=d_learning_rate, betas=optim_betas)
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=d_learning_rate, betas=optim_betas, weight_decay=l2)
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=d_learning_rate, betas=optim_betas, weight_decay=l2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号