dnn.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:StockPredictor 作者: wallsbreaker 项目源码 文件源码
def predict_classify(target_var, target_labels, model_path,
                     test_data_path='../../data/test',
                     output_path='../../data/prediction'):
    """Restore a saved classifier, evaluate it on test data and dump predictions.

    The network is rebuilt with ``get_model_by_strategy``. The strategy name
    is the part of the model file's base name before the first underscore.
    The model file is read as one JSON-encoded parameter array per line, and
    the arrays are passed in file order to ``set_all_param_values``.

    Args:
        target_var: Ignored. A fresh ``T.imatrix('y')`` is always built
            instead. Kept only so existing callers keep working.
        target_labels: Ignored for the same reason as ``target_var``.
        model_path: Path of the saved parameter file.
        test_data_path: Location passed to ``load_dataset``. Defaults to the
            previously hard-coded ``'../../data/test'``.
        output_path: File the per-sample predictions are written to.
            Defaults to the previously hard-coded ``'../../data/prediction'``.

    Side effects:
        Writes the mean ``binary_accuracy`` (times 100) to stdout. Writes one
        ``label<TAB>value<TAB>prediction`` line per test sample to
        ``output_path``, where ``prediction`` is the first column of the
        network output for that sample.

    Raises:
        ValueError: if ``values`` or the network output has fewer entries
            than ``labels``. The check runs before the output file is opened,
            so no partial file is written.
    """
    import os  # local import: the file's top-level import block is not visible here

    # The caller-supplied symbolic variables are deliberately replaced. The
    # compiled function below is always built against this integer matrix.
    target_var = T.imatrix('y')
    target_labels = target_var

    # "<dir>/<strategy>_<anything>" -> "<strategy>"
    dnn_strategy = os.path.basename(model_path).split('_')[0]
    network = get_model_by_strategy(dnn_strategy)

    # Load parameters: one JSON array per line, in set_all_param_values order.
    params = []
    with open(model_path, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. an extra trailing newline
            params.append(np.array(json.loads(line)))
    set_all_param_values(network, params)

    # Build the symbolic graph. deterministic=True asks Lasagne for
    # inference-mode output.
    prediction_expr = get_output(network, deterministic=True)
    accuracy_expr = binary_accuracy(prediction_expr, target_labels).mean()

    input_layer = get_all_layers(network)[0]
    predict = theano.function([input_layer.input_var, target_var],
                              [prediction_expr, accuracy_expr])

    X, labels, values, _ = load_dataset(test_data_path)
    predictions, accuracy = predict(X, labels)

    sys.stdout.write("  predict accuracy:\t\t\t{} %\n".format(accuracy * 100))

    # Fail before touching the output file instead of writing a partial one.
    n_samples = len(labels)
    if len(values) < n_samples or len(predictions) < n_samples:
        raise ValueError(
            'expected at least {} values and predictions, got {} and {}'.format(
                n_samples, len(values), len(predictions)))

    # Output: label, value and the first column of the network output per sample.
    with open(output_path, 'w') as f:
        f.writelines(
            str(label) + '\t' + str(value) + '\t' + str(pred[0]) + '\n'
            for label, value, pred in zip(labels, values, predictions))
    sys.stdout.flush()
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号