train.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:gconv_experiments 作者: tscohen 项目源码 文件源码
def get_model_and_optimizer(result_dir, modelfn, opt, opt_kwargs, net_kwargs, gpu):
    model_fn = os.path.basename(modelfn)
    model_name = model_fn.split('.')[0]
    module = imp.load_source(model_name, modelfn)
    Net = getattr(module, model_name)

    dst = '%s/%s' % (result_dir, model_fn)
    if not os.path.exists(dst):
        shutil.copy(modelfn, dst)

    dst = '%s/%s' % (result_dir, os.path.basename(__file__))
    if not os.path.exists(dst):
        shutil.copy(__file__, dst)

    # prepare model
    model = Net(**net_kwargs)
    if gpu >= 0:
        model.to_gpu()

    optimizer = optimizers.__dict__[opt](**opt_kwargs)
    optimizer.setup(model)

    return model, optimizer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号