def main():
    """Load a trained regression MLP, predict on the dataset, and plot the fit.

    Command-line options:
        --modelpath/-m: snapshot file passed to ``serializers.load_npz``
            (default ``result/mlp.model``).
        --gpu/-g: device id forwarded to ``SklearnWrapperRegressor``
            (a negative value selects the CPU, per the option's help text).
        --unit/-u: first argument to ``MLP`` (default 50); presumably it
            must match the value used at training time — TODO confirm.

    Side effects:
        Writes the plot to ``predict.png`` and then displays it.
    """
    parser = argparse.ArgumentParser(description='Regression predict')
    parser.add_argument('--modelpath', '-m', default='result/mlp.model',
                        help='Model path to be loaded')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    args = parser.parse_args()

    batchsize = 128

    # Load dataset and reshape both arrays into single-column float32
    # matrices (one feature and one target per row).
    data, target = load_data()
    X = data.reshape((-1, 1)).astype(np.float32)
    y = target.reshape((-1, 1)).astype(np.float32)

    # Load trained model. The network built here must have the same
    # structure as the one the snapshot was saved from, or load_npz fails.
    model = SklearnWrapperRegressor(MLP(args.unit, 1), device=args.gpu)
    serializers.load_npz(args.modelpath, model)

    # --- Example 1. Predict all test data ---
    outputs = model.predict(X, batchsize=batchsize, retain_inputs=False)

    # --- Plot result ---
    # plt.plot joins points in array order, so sort by x first; otherwise
    # unsorted inputs would draw a zig-zag instead of the fitted curve.
    order = np.argsort(X[:, 0])
    plt.figure()
    plt.scatter(X, y, label='actual')
    plt.plot(X[order], outputs[order], label='predict', color='red')
    plt.legend()
    # Save BEFORE show(): after the window opened by plt.show() is closed
    # the figure is released, and a later savefig() writes an empty image.
    plt.savefig('predict.png')
    plt.show()
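# NOTE: two lines of page-navigation residue ("Comment list", "Table of contents") from the web page this script was copied from stood here; kept as a comment so the file parses.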