train_utils.py source code

python

Project: chainer-segnet  Author: pfnet-research
import imp  # deprecated since Python 3.4 and removed in Python 3.12
import os
import shutil


def get_model(
        model_file, model_name, loss_file, loss_name, class_weight, n_encdec,
        n_classes, in_channel, n_mid, train_depth=None, result_dir=None):
    # Load the model and loss classes by name from their source files
    model = imp.load_source(model_name, model_file)
    model = getattr(model, model_name)
    loss = imp.load_source(loss_name, loss_file)
    loss = getattr(loss, loss_name)

    # Initialize the model; wrap it in the loss when train_depth is set
    model = model(n_encdec, n_classes, in_channel, n_mid)
    if train_depth:
        model = loss(model, class_weight, train_depth)

    # Copy the model and loss source files into result_dir if not already there
    if result_dir is not None:
        for src in (model_file, loss_file):
            dst = os.path.join(result_dir, os.path.basename(src))
            if not os.path.exists(dst):
                shutil.copy(src, dst)

    return model
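
The function above relies on `imp.load_source`, which is not available on Python 3.12 and later. A minimal sketch of the same "load a class by name from a file path" step using `importlib` might look like the following. The helper name `load_attr_from_file` and the throwaway `DummyModel` file are assumptions made for illustration; they are not part of chainer-segnet.

```python
import importlib.util
import os
import tempfile


def load_attr_from_file(name, path):
    """Import the Python file at `path` and return its attribute `name`.

    Hypothetical helper: stands in for the imp.load_source + getattr pair.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, name)


# Demo: a temporary file stands in for a real model definition file.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'dummy_model.py')
    with open(path, 'w') as f:
        f.write('class DummyModel(object):\n    n_classes = 12\n')
    DummyModel = load_attr_from_file('DummyModel', path)
    demo_n_classes = DummyModel.n_classes
```

One difference worth noting: unlike `imp.load_source`, this sketch does not register the loaded module in `sys.modules`, so code that later imports the module by name would need that step added.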