models.py source code

python

Project: GAN-Zoo   Author: corenel
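import os

import torch
import torch.backends.cudnn as cudnn

# params, Discriminator, Generator and init_weights come from other
# GAN-Zoo modules that are not shown on this page.


def get_models():
    """Build D and G, initialize their weights, optionally restore
    saved checkpoints, and move both models to the GPU if available."""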
    D = Discriminator(num_channels=params.num_channels,
                      conv_dim=params.d_conv_dim,
                      image_size=params.image_size,
                      num_gpu=params.num_gpu,
                      num_extra_layers=params.num_extra_layers,
                      use_BN=True)
    G = Generator(num_channels=params.num_channels,
                  z_dim=params.z_dim,
                  conv_dim=params.g_conv_dim,
                  image_size=params.image_size,
                  num_gpu=params.num_gpu,
                  num_extra_layers=params.num_extra_layers,
                  use_BN=params.use_BN)

    # initialize the weights of both networks (apply visits every submodule)
    D.apply(init_weights)
    G.apply(init_weights)

    # restore saved weights when a checkpoint path is given and the file exists
    if (params.d_model_restore is not None
            and os.path.exists(params.d_model_restore)):
        D.load_state_dict(torch.load(params.d_model_restore))
    if (params.g_model_restore is not None
            and os.path.exists(params.g_model_restore)):
        G.load_state_dict(torch.load(params.g_model_restore))

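    # move both networks to the GPU when CUDA is available; cudnn.benchmark
    # lets cuDNN pick the fastest kernels for fixed input sizes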
    if torch.cuda.is_available():
        cudnn.benchmark = True
        D.cuda()
        G.cuda()

    print(D)
    print(G)

    return D, G
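This page does not include init_weights, the params module, or any checkpoint files, so the following is only a minimal, self-contained sketch under stated assumptions. init_weights here is a hypothetical DCGAN-style initializer (convolution weights drawn from N(0, 0.02), BatchNorm scale from N(1, 0.02) and shift set to 0); the real GAN-Zoo helper may differ. restore_if_exists is a hypothetical helper that mirrors the restore branch of get_models, and a tiny nn.Sequential stands in for the GAN-Zoo Discriminator and Generator.

```python
import os
import tempfile

import torch
import torch.nn as nn


def init_weights(layer):
    """Hypothetical DCGAN-style initializer (not the GAN-Zoo original)."""
    layer_name = layer.__class__.__name__
    if layer_name.find("Conv") != -1:
        # conv / transposed-conv weights drawn from N(0, 0.02)
        layer.weight.data.normal_(0.0, 0.02)
    elif layer_name.find("BatchNorm") != -1:
        # batch-norm scale drawn from N(1, 0.02), shift set to zero
        layer.weight.data.normal_(1.0, 0.02)
        layer.bias.data.fill_(0)


def restore_if_exists(model, path):
    """Hypothetical helper: load a state_dict only if the file exists."""
    if path is not None and os.path.exists(path):
        # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host
        model.load_state_dict(torch.load(path, map_location="cpu"))
        return True
    return False


def tiny_net():
    # stand-in model, NOT the GAN-Zoo Discriminator/Generator
    return nn.Sequential(nn.Conv2d(3, 8, 4, 2, 1),
                         nn.BatchNorm2d(8),
                         nn.LeakyReLU(0.2))


net = tiny_net()
net.apply(init_weights)  # called on every submodule, then on net itself

# save a checkpoint, then restore it into a freshly built model
ckpt = os.path.join(tempfile.mkdtemp(), "net.pkl")
torch.save(net.state_dict(), ckpt)

fresh = tiny_net()
restored = restore_if_exists(fresh, ckpt)  # file exists -> weights loaded
missing = restore_if_exists(fresh, None)   # no path -> left unchanged

# device-agnostic alternative to calling .cuda() only when CUDA exists
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
```

Module.apply is what lets get_models initialize an entire network with one call per model. The map_location argument is an addition in this sketch: the original get_models calls torch.load without it, so its restore step assumes the checkpoint was saved on a device that is also available when loading.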