predict_regression.py source code

python

Project: chainer_sklearn    Author: corochann
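import argparse

import matplotlib.pyplot as plt
import numpy as np
from chainer import serializers

# SklearnWrapperRegressor, MLP and load_data are defined elsewhere in the
# chainer_sklearn project; their imports are not part of this snippet.


def main():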
    parser = argparse.ArgumentParser(description='Regression predict')
    parser.add_argument('--modelpath', '-m', default='result/mlp.model',
                        help='Model path to be loaded')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    args = parser.parse_args()

    batchsize = 128

    # Load dataset and reshape it into (N, 1) float32 column vectors
    data, target = load_data()
    X = data.reshape((-1, 1)).astype(np.float32)
    y = target.reshape((-1, 1)).astype(np.float32)

    # Load trained model
    model = SklearnWrapperRegressor(MLP(args.unit, 1), device=args.gpu)
    serializers.load_npz(args.modelpath, model)

    # --- Example 1. Predict all test data ---
    outputs = model.predict(X, batchsize=batchsize, retain_inputs=False)

    # --- Plot result ---
    plt.figure()
    plt.scatter(X, y, label='actual')
    plt.plot(X, outputs, label='predict', color='red')
    plt.legend()
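    # Save before show(): once the window is closed the figure may be
    # discarded, and savefig() would then write a blank image.
    plt.savefig('predict.png')
    plt.show()


if __name__ == '__main__':
    main()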
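The script's command-line options can be checked without a shell by passing an explicit argument list to `parse_args`. The sketch below rebuilds the same parser inside a hypothetical `build_parser()` helper (not part of the original script) and parses two sample command lines.

```python
import argparse


def build_parser():
    # Hypothetical helper: the same options that main() defines.
    parser = argparse.ArgumentParser(description='Regression predict')
    parser.add_argument('--modelpath', '-m', default='result/mlp.model',
                        help='Model path to be loaded')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    return parser


parser = build_parser()

# No arguments: every option falls back to its default.
defaults = parser.parse_args([])
print(defaults.modelpath, defaults.gpu, defaults.unit)  # result/mlp.model -1 50

# Short flags work too; type=int converts the string values.
custom = parser.parse_args(['-g', '0', '-u', '100'])
print(custom.modelpath, custom.gpu, custom.unit)  # result/mlp.model 0 100
```

Because `MLP(args.unit, 1)` is built before the weights are loaded, `--unit` presumably has to match the value used when the model was trained; otherwise the stored parameter shapes will not fit the network.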
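`load_data()` is not shown in this snippet; the two lines after it turn its output into `(N, 1)` float32 arrays, one sample per row with a single feature. The sketch below uses made-up 1-D arrays as a stand-in for `load_data()` (an assumption about what it returns) to show the effect of that reshaping.

```python
import numpy as np

# Stand-in for load_data(): assumed to return two 1-D float64 arrays.
data = np.linspace(-1.0, 1.0, 5)
target = data ** 2

# Same conversion as in main(): (N,) float64 -> (N, 1) float32.
X = data.reshape((-1, 1)).astype(np.float32)
y = target.reshape((-1, 1)).astype(np.float32)

print(data.shape, data.dtype)  # (5,) float64
print(X.shape, X.dtype)        # (5, 1) float32
```

The cast matters because Chainer links store float32 parameters by default, so float32 inputs avoid a dtype mismatch in the forward pass.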
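`model.predict(X, batchsize=batchsize)` runs the network over the whole dataset in chunks of 128 rows rather than in one pass. The sketch below illustrates that idea with a hypothetical `predict_in_batches` helper and a plain NumPy function standing in for the trained network; it is not chainer_sklearn's actual implementation.

```python
import numpy as np


def predict_in_batches(forward, X, batchsize=128):
    """Apply `forward` to consecutive slices of X and stack the results.

    Hypothetical helper illustrating minibatch prediction; assumes len(X) > 0.
    """
    outputs = []
    for start in range(0, len(X), batchsize):
        batch = X[start:start + batchsize]  # the last batch may be shorter
        outputs.append(forward(batch))
    return np.concatenate(outputs, axis=0)


def forward(x):
    # Stand-in "network": a fixed linear map y = 2x + 1.
    return 2.0 * x + 1.0


X = np.arange(300, dtype=np.float32).reshape((-1, 1))
outputs = predict_in_batches(forward, X, batchsize=128)  # batches of 128, 128, 44
print(outputs.shape)  # (300, 1)
```

For a per-sample function like this one, the batched result equals a single forward pass over all of `X`; batching only bounds how much data is processed at once.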