kera.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: Quantrade　作者: quant-trade　项目源码　文件源码
def main():
    """Train an RNN on the SData split, then evaluate or export its predictions.

    Loads train/test features and targets from ``SData``, adds a trailing
    axis of length 1 to the feature matrices, fits ``RNN([1, 100, 100, 1])``
    on the training data and predicts on the test data.

    With ``validate`` True (hard-coded below) the mean squared error and the
    log loss against the test targets are printed.  Otherwise the predictions
    are written to ``results/dl.csv`` in this file's directory.
    """
    validate = True
    n = SData(validate=validate)

    # ``DataFrame.as_matrix()`` was deprecated in pandas 0.23 and removed in
    # 1.0; ``.values`` returns the same ndarray on every pandas version.
    Xtrain = n.train_features.values
    ytrain = n.train_targets
    Xtest = n.test_features.values
    ytest = n.test_targets

    # (rows, columns) -> (rows, columns, 1).  NOTE(review): presumably the
    # (samples, timesteps, features) layout RNN expects — confirm against RNN.
    Xtrain = np.reshape(Xtrain, (Xtrain.shape[0], Xtrain.shape[1], 1))
    Xtest = np.reshape(Xtest, (Xtest.shape[0], Xtest.shape[1], 1))

    rnn = RNN([1, 100, 100, 1])
    rnn.fit(Xtrain, ytrain)
    # Run inference once; the original called predict() twice on identical
    # input to obtain two copies of the same output.
    p = rnn.predict(Xtest)
    p_prob = p

    if validate:
        mse = mean_squared_error(ytest, p)
        print("MSE: {}".format(mse))
        # NOTE(review): log_loss expects class probabilities in [0, 1];
        # confirm RNN.predict output satisfies that.
        loss = log_loss(ytest, p_prob)
        print("Log loss: {}".format(loss))
    else:
        base_path = dirname(__file__)
        # ``results`` was never defined in the original, so this branch always
        # raised NameError.  Use the test predictions, flattened to 1-D so they
        # form a single DataFrame column.
        results = np.ravel(p_prob)
        results_df = DataFrame(data={'probability': results})
        # NOTE(review): ``t_id`` (presumably the test-row identifiers) is not
        # defined in this function; it must be provided before this branch
        # can run.  TODO confirm its source (e.g. an attribute of ``n``).
        joined = DataFrame(t_id).join(results_df)
        joined.to_csv(join(base_path, 'results', 'dl.csv'), index=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号