def __init__(self, nic, noc, ngf, ndf, beta=0.5, lamb=100, lr=1e-3, cuda=True, crayon=False):
    """Build the generator/discriminator pair, their Adam optimizers and the loss functions.

    Args:
        nic: Number of input channels.
        noc: Number of output channels.
        ngf: Number of generator filters.
        ndf: Number of discriminator filters.
        beta: Adam ``beta1`` coefficient used by both optimizers
            (``beta2`` is fixed at 0.999).
        lamb: Weight on the L1 term in the objective.
        lr: Learning rate used by both Adam optimizers.
        cuda: Stored on ``self.cuda``; presumably consulted by
            ``self.cudafy`` to decide whether to move objects to the GPU
            (TODO confirm against ``cudafy``).
        crayon: If true, create a ``CrayonClient`` for localhost:8889 and
            store a ``'pix2pix'`` experiment handle on ``self.logger``.
    """
    self.cuda = cuda
    self.start_epoch = 0
    self.crayon = crayon
    if crayon:
        self.cc = CrayonClient(hostname="localhost", port=8889)
        try:
            self.logger = self.cc.create_experiment('pix2pix')
        except Exception:
            # Creation failed -- presumably because an experiment named
            # 'pix2pix' is left over from an earlier run (TODO confirm the
            # exception type CrayonClient raises and narrow further).
            # Drop it and create a fresh one. Catching Exception rather than
            # using a bare except lets KeyboardInterrupt/SystemExit propagate.
            self.cc.remove_experiment('pix2pix')
            self.logger = self.cc.create_experiment('pix2pix')

    self.gen = self.cudafy(Generator(nic, noc, ngf))
    self.dis = self.cudafy(Discriminator(nic, noc, ndf))

    # The optimizers are built after the networks pass through cudafy. If
    # cudafy moves them to the GPU, the optimizers hold the moved parameters.
    # NOTE(review): cudafy is also applied to the optimizer objects; torch
    # optimizers have no .cuda() method, so confirm cudafy tolerates them.
    betas = (beta, 0.999)
    # Optimizer for the generator
    self.gen_optim = self.cudafy(optim.Adam(
        self.gen.parameters(), lr=lr, betas=betas))
    # Optimizer for the discriminator
    self.dis_optim = self.cudafy(optim.Adam(
        self.dis.parameters(), lr=lr, betas=betas))

    # Loss functions
    self.criterion_bce = nn.BCELoss()
    self.criterion_mse = nn.MSELoss()
    self.criterion_l1 = nn.L1Loss()
    self.lamb = lamb
# Stray web-page residue (not code): "评论列表" ("Comment list"), "文章目录" ("Table of contents").