models.py source code

python

Project: GAN-Zoo    Author: corenel
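```python
import os

import torch
import torch.backends.cudnn as cudnn

# Discriminator, Generator and init_weights come from the GAN-Zoo project
# and are not shown in this excerpt.


def get_models(num_channels, d_conv_dim, g_conv_dim, z_dim, num_gpu,
               d_model_restore=None, g_model_restore=None):
    """Build D and G, initialize their weights, optionally restore saved
    weights, and move both models to the GPU when CUDA is available."""
    D = Discriminator(num_channels=num_channels,
                      conv_dim=d_conv_dim,
                      num_gpu=num_gpu)
    G = Generator(num_channels=num_channels,
                  z_dim=z_dim,
                  conv_dim=g_conv_dim,
                  num_gpu=num_gpu)

    # initialize the weights of every submodule
    D.apply(init_weights)
    G.apply(init_weights)

    # restore saved weights if a checkpoint path is given and exists;
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines
    # (the models are moved to the GPU below when CUDA is available)
    if d_model_restore is not None and os.path.exists(d_model_restore):
        D.load_state_dict(torch.load(d_model_restore, map_location="cpu"))
    if g_model_restore is not None and os.path.exists(g_model_restore):
        G.load_state_dict(torch.load(g_model_restore, map_location="cpu"))

    # if CUDA is available, let cuDNN benchmark convolution algorithms
    # and move both models to the GPU
    if torch.cuda.is_available():
        cudnn.benchmark = True
        D.cuda()
        G.cuda()

    return D, G
```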
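The `init_weights` function passed to `apply()` is not shown in this file. The sketch below assumes the common DCGAN convention: conv weights drawn from N(0, 0.02), BatchNorm weights from N(1, 0.02), and BatchNorm biases set to zero. The project's actual initializer may differ. The toy model built by `make_toy` and the helper `restore_if_exists` are hypothetical stand-ins for the project's `Generator` and its checkpoint-restore step. The sketch requires PyTorch.

```python
import os
import tempfile

import torch
import torch.nn as nn


def init_weights(m: nn.Module) -> None:
    """DCGAN-style initialization (an assumption, not GAN-Zoo's own code)."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)


def make_toy() -> nn.Module:
    # Hypothetical stand-in for the first block of a DCGAN generator.
    return nn.Sequential(
        nn.ConvTranspose2d(100, 8, kernel_size=4, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(8),
        nn.ReLU(inplace=True),
    )


def restore_if_exists(model: nn.Module, path) -> bool:
    # Hypothetical helper mirroring the restore step in get_models.
    if path is not None and os.path.exists(path):
        model.load_state_dict(torch.load(path, map_location="cpu"))
        return True
    return False


toy = make_toy()
toy.apply(init_weights)  # visits every submodule recursively, then toy itself

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "g.pth")
    torch.save(toy.state_dict(), ckpt)
    fresh = make_toy()
    restored = restore_if_exists(fresh, ckpt)
```

`Module.apply` calls the given function on every submodule and then on the module itself, so a single `G.apply(init_weights)` reaches every layer. Loading with `map_location="cpu"` before calling `.cuda()` lets one checkpoint restore on machines with or without a GPU.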