train.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:gconv_experiments 作者: tscohen 项目源码 文件源码
def get_model_and_optimizer(result_dir, modelfn, opt, opt_kwargs, net_kwargs, gpu):
    """Build the network and its optimizer, snapshotting the sources used.

    Args:
        result_dir: Directory into which the model definition file and this
            training script are copied, so a run can be reproduced later.
            Copies that already exist there are left untouched.
        modelfn: Path to a Python source file. The callable whose name equals
            the file's basename up to its first '.' (e.g. ``mynet`` for
            ``mynet.py``) is used to construct the network.
        opt: Name of an optimizer class, looked up in the ``optimizers``
            module's namespace (presumably ``chainer.optimizers``, imported
            elsewhere in this file -- confirm).
        opt_kwargs: Keyword arguments for the optimizer constructor; ``None``
            is treated as "no arguments".
        net_kwargs: Keyword arguments for the network constructor; ``None``
            is treated as "no arguments".
        gpu: Device id. When ``>= 0`` the model is moved there with
            ``model.to_gpu(gpu)``; otherwise it is left where it was created.

    Returns:
        A ``(model, optimizer)`` tuple; ``optimizer.setup(model)`` has
        already been called.

    Raises:
        KeyError: If ``opt`` is not defined in ``optimizers``.
    """
    model_fn = os.path.basename(modelfn)
    # Split at the first '.' (not os.path.splitext) so the attribute looked up
    # in the loaded module stays exactly what it has always been.
    model_name = model_fn.split('.')[0]
    module = _load_module_from_path(model_name, modelfn)
    net = getattr(module, model_name)

    # Copy the model definition, then this train script, into the result dir.
    # An existing copy is never overwritten (e.g. when resuming a run).
    for src in (modelfn, __file__):
        dst = os.path.join(result_dir, os.path.basename(src))
        if not os.path.exists(dst):
            shutil.copy(src, dst)

    # Create model
    model = net(**(net_kwargs or {}))
    if gpu >= 0:
        model.to_gpu(gpu)

    # Create optimizer (dict lookup keeps the KeyError contract for bad names)
    optimizer = optimizers.__dict__[opt](**(opt_kwargs or {}))
    optimizer.setup(model)

    return model, optimizer


def _load_module_from_path(name, path):
    """Execute the Python source file at ``path`` and return it as module ``name``.

    Uses ``importlib`` when available: the ``imp`` module is deprecated since
    Python 3.4 and removed in 3.12. It falls back to ``imp.load_source`` on
    older interpreters. Like ``imp.load_source``, the loaded module is
    registered in ``sys.modules`` under ``name``.
    """
    import sys

    try:
        import importlib.machinery
        import importlib.util
        module_from_spec = importlib.util.module_from_spec  # Python >= 3.5
    except (ImportError, AttributeError):
        # Legacy interpreters (Python 2, Python < 3.5) only have imp.
        import imp
        return imp.load_source(name, path)

    # An explicit SourceFileLoader accepts any file extension, matching
    # imp.load_source instead of spec_from_file_location's suffix sniffing.
    loader = importlib.machinery.SourceFileLoader(name, path)
    spec = importlib.util.spec_from_file_location(name, path, loader=loader)
    module = module_from_spec(spec)
    sys.modules[name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        sys.modules.pop(name, None)
        raise
    return module
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号