def build_model(self):
    """Build the two generators, the two discriminators, and their optimizers.

    Sets the following attributes on ``self``:

    * ``g12`` / ``g21`` -- generator networks (``G12`` / ``G21``), each built
      from ``self.config`` with ``conv_dim=self.g_conv_dim``.
    * ``d1`` / ``d2`` -- discriminator networks (``D1`` / ``D2``), each built
      with ``conv_dim=self.d_conv_dim``.
    * ``g_optimizer`` -- a single Adam optimizer over the parameters of both
      generators.
    * ``d_optimizer`` -- a single Adam optimizer over the parameters of both
      discriminators.

    Both optimizers use learning rate ``self.lr`` and betas
    ``(self.beta1, self.beta2)``. When CUDA is available, all four networks
    are moved to the GPU before the optimizers are constructed.
    """
    self.g12 = G12(self.config, conv_dim=self.g_conv_dim)
    self.g21 = G21(self.config, conv_dim=self.g_conv_dim)
    self.d1 = D1(conv_dim=self.d_conv_dim)
    self.d2 = D2(conv_dim=self.d_conv_dim)

    # Move the networks to the GPU *before* building the optimizers. The
    # PyTorch optimizer docs warn that parameters may be different objects
    # after ``.cuda()``; an optimizer built earlier would then hold
    # references to the stale pre-move tensors.
    if torch.cuda.is_available():
        for net in (self.g12, self.g21, self.d1, self.d2):
            net.cuda()

    # Each optimizer covers a pair of networks, so both generators (and both
    # discriminators) are updated together by one step() call.
    g_params = list(self.g12.parameters()) + list(self.g21.parameters())
    d_params = list(self.d1.parameters()) + list(self.d2.parameters())
    betas = (self.beta1, self.beta2)
    self.g_optimizer = optim.Adam(g_params, lr=self.lr, betas=betas)
    self.d_optimizer = optim.Adam(d_params, lr=self.lr, betas=betas)
评论列表
文章目录