def _restore_weights(model, checkpoint_path):
    """Load a saved state dict into ``model`` if ``checkpoint_path`` exists.

    A ``None`` path, or a path that does not exist on disk, is skipped
    silently so the model keeps whatever weights it already has.
    """
    if checkpoint_path is None or not os.path.exists(checkpoint_path):
        return
    # Deserialize every tensor onto the CPU. Without this, a checkpoint that
    # was saved from a CUDA model cannot be loaded on a machine without CUDA.
    # The lambda form of map_location is accepted by all torch versions; the
    # caller moves the model to the GPU afterwards when one is available.
    state_dict = torch.load(checkpoint_path,
                            map_location=lambda storage, _loc: storage)
    model.load_state_dict(state_dict)


def get_models(num_channels, d_conv_dim, g_conv_dim, z_dim, num_gpu,
               d_model_restore=None, g_model_restore=None):
    """Build the discriminator and generator, ready for training.

    Both models are constructed, initialized with ``init_weights``,
    optionally restored from saved state dicts, and moved to the GPU when
    CUDA is available.

    Args:
        num_channels: Passed as ``num_channels`` to both model constructors.
        d_conv_dim: Passed as ``conv_dim`` to ``Discriminator``.
        g_conv_dim: Passed as ``conv_dim`` to ``Generator``.
        z_dim: Passed as ``z_dim`` to ``Generator``.
        num_gpu: Passed as ``num_gpu`` to both model constructors.
        d_model_restore: Optional path to a saved ``Discriminator`` state
            dict. It is ignored if ``None`` or if the file does not exist.
        g_model_restore: Optional path to a saved ``Generator`` state dict.
            It is ignored if ``None`` or if the file does not exist.

    Returns:
        A ``(D, G)`` tuple holding the discriminator and the generator.

    Side effects:
        Sets the global ``cudnn.benchmark = True`` when CUDA is available.
    """
    D = Discriminator(num_channels=num_channels,
                      conv_dim=d_conv_dim,
                      num_gpu=num_gpu)
    G = Generator(num_channels=num_channels,
                  z_dim=z_dim,
                  conv_dim=g_conv_dim,
                  num_gpu=num_gpu)

    # Initialize every submodule first. A restored checkpoint below then
    # overwrites these initial values.
    D.apply(init_weights)
    G.apply(init_weights)

    _restore_weights(D, d_model_restore)
    _restore_weights(G, g_model_restore)

    if torch.cuda.is_available():
        # Let cuDNN autotune its convolution algorithms. This is a
        # process-wide setting, not a per-model one.
        cudnn.benchmark = True
        D.cuda()
        G.cuda()
    return D, G
评论列表
文章目录