def _restore_weights(model, path):
    """Load a saved state dict from ``path`` into ``model`` if that file exists.

    A ``None`` path or a path that does not exist is skipped silently, so the
    model keeps its freshly initialised weights.

    The checkpoint is mapped to CPU while loading. This lets weights saved on a
    GPU machine be restored on a CPU-only machine. ``load_state_dict`` then
    copies the tensors into the model's existing parameters.
    """
    if path is not None and os.path.exists(path):
        model.load_state_dict(torch.load(path, map_location="cpu"))


def get_models():
    """Build the discriminator and generator, initialise and restore weights.

    Steps:
      1. Construct ``Discriminator`` (D) and ``Generator`` (G) from the
         module-level ``params`` configuration.
      2. Apply ``init_weights`` to every submodule of both models.
      3. Optionally overwrite the weights from ``params.d_model_restore`` /
         ``params.g_model_restore`` when those files exist.
      4. If CUDA is available, enable cuDNN benchmarking and move both models
         to the GPU.
      5. Print both model architectures.

    Returns:
        tuple: ``(D, G)``, the discriminator and generator modules.
    """
    # NOTE(review): D always uses batch norm, while G honours params.use_BN.
    # Confirm this asymmetry is intended before changing it.
    D = Discriminator(num_channels=params.num_channels,
                      conv_dim=params.d_conv_dim,
                      image_size=params.image_size,
                      num_gpu=params.num_gpu,
                      num_extra_layers=params.num_extra_layers,
                      use_BN=True)
    G = Generator(num_channels=params.num_channels,
                  z_dim=params.z_dim,
                  conv_dim=params.g_conv_dim,
                  image_size=params.image_size,
                  num_gpu=params.num_gpu,
                  num_extra_layers=params.num_extra_layers,
                  use_BN=params.use_BN)

    # Initialise weights first, so that restored checkpoints (if any) win.
    D.apply(init_weights)
    G.apply(init_weights)

    # Restore saved weights. Loading happens while the models are still on
    # CPU, hence the CPU map_location inside the helper.
    _restore_weights(D, params.d_model_restore)
    _restore_weights(G, params.g_model_restore)

    if torch.cuda.is_available():
        # Let cuDNN autotune convolution algorithms for the fixed input size.
        cudnn.benchmark = True
        D.cuda()
        G.cuda()

    print(D)
    print(G)
    return D, G
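# NOTE: stray web-page navigation labels ("Comment list", "Table of contents")
# were scraped into this file here; they are not part of the code.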