trainer.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:rarepepes 作者: kendricktan 项目源码 文件源码
def __init__(self, nic, noc, ngf, ndf, beta=0.5, lamb=100, lr=1e-3, cuda=True, crayon=False):
        """
        Set up the generator/discriminator pair, their optimizers and losses.

        Args:
            nic: Number of input channel
            noc: Number of output channels
            ngf: Number of generator filters
            ndf: Number of discriminator filters
            beta: First coefficient of the Adam ``betas`` pair (the second
                is fixed at 0.999); used for both optimizers
            lamb: Weight on L1 term in objective
            lr: Learning rate passed to both Adam optimizers
            cuda: Stored as ``self.cuda``; presumably consulted by
                ``self.cudafy`` to decide whether to move objects to the
                GPU (``cudafy`` is defined elsewhere -- confirm there)
            crayon: When truthy, connect to a Crayon server on
                localhost:8889 and (re)create the 'pix2pix' experiment,
                exposing it as ``self.logger``
        """
        # Must be set before any self.cudafy(...) call below.
        self.cuda = cuda
        self.start_epoch = 0

        self.crayon = crayon
        if crayon:
            self.cc = CrayonClient(hostname="localhost", port=8889)

            try:
                self.logger = self.cc.create_experiment('pix2pix')
            except Exception:
                # Creation failed -- most likely the experiment already
                # exists -- so drop the stale one and start fresh.
                # `Exception` (not a bare except) so that Ctrl-C /
                # SystemExit are not swallowed into deleting the experiment.
                self.cc.remove_experiment('pix2pix')
                self.logger = self.cc.create_experiment('pix2pix')

        self.gen = self.cudafy(Generator(nic, noc, ngf))
        self.dis = self.cudafy(Discriminator(nic, noc, ndf))

        # NOTE(review): the optimizers are also passed through cudafy;
        # torch optimizers have no .cuda() method, so this relies on
        # cudafy tolerating such objects -- confirm against its definition.

        # Optimizers for generators
        self.gen_optim = self.cudafy(optim.Adam(
            self.gen.parameters(), lr=lr, betas=(beta, 0.999)))

        # Optimizers for discriminators
        self.dis_optim = self.cudafy(optim.Adam(
            self.dis.parameters(), lr=lr, betas=(beta, 0.999)))

        # Loss functions
        self.criterion_bce = nn.BCELoss()
        self.criterion_mse = nn.MSELoss()
        self.criterion_l1 = nn.L1Loss()

        self.lamb = lamb
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号