nn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:stock-price-prediction 作者: chinuy 项目源码 文件源码
def _to_lstm_input(features):
    """Reshape a 2-D feature matrix into 3-D input for an LSTM layer.

    Each sample becomes a sequence of length one, so the result has shape
    ``(n_samples, 1, n_features)``.

    Args:
        features: array-like that is assumed to be 2-D
            ``(n_samples, n_features)`` -- confirm against
            ``classifier.prepareDataForClassification``. It is converted with
            ``numpy.asarray`` *before* its shape is read, so plain nested
            lists work as well as arrays / DataFrames.

    Returns:
        numpy.ndarray of shape ``(n_samples, 1, n_features)``.
    """
    matrix = numpy.asarray(features)
    return matrix.reshape((matrix.shape[0], 1, matrix.shape[1]))


def _build_model(n_features):
    """Build and compile the stacked-LSTM binary classifier.

    Args:
        n_features: size of the last input axis (features per time step).

    Returns:
        A compiled Keras ``Sequential`` model with a single sigmoid output,
        compiled with binary cross-entropy loss, the Adam optimizer and an
        accuracy metric.
    """
    model = Sequential()
    # The time axis is left as None so any sequence length is accepted;
    # main() feeds length-1 sequences.
    model.add(LSTM(128, input_shape=(None, n_features), return_sequences=True))
    model.add(Dropout(0.2))
    # return_sequences=False: only the last output is passed to the dense head.
    model.add(LSTM(240, return_sequences=False))
    model.add(Dropout(0.2))
    model.add(Dense(units=1))
    model.add(Activation('sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model


def main():
    """Train an LSTM classifier on SPY data and print its results.

    Steps:
    1. Fetch SPY data for 2010-01-01..2015-12-31 via ``util.get_data``.
    2. Apply ``util.applyFeatures`` / ``util.preprocessData``.
    3. Split with ``classifier.prepareDataForClassification`` using a
       2015-01-01 cut-off. Presumably later samples form the test set;
       confirm in that helper.
    4. Train the model.
    5. Print the training-set predictions, the training-set
       [loss, accuracy] and the held-out test-set [loss, accuracy].
    """
    stock_name = 'SPY'
    max_delta = 4
    start = datetime.datetime(2010, 1, 1)
    end = datetime.datetime(2015, 12, 31)
    start_test = datetime.datetime(2015, 1, 1)

    # Step 1: data and features.
    dataset = util.get_data(stock_name, start, end)
    # Offsets 1..max_delta-1, i.e. 1, 2, 3. Presumably these are lag windows
    # for util.applyFeatures -- confirm there.
    deltas = range(1, max_delta)
    dataset = util.applyFeatures(dataset, deltas)
    dataset = util.preprocessData(dataset)
    X_train, y_train, X_test, y_test = \
        classifier.prepareDataForClassification(dataset, start_test)

    X_train = _to_lstm_input(X_train)
    X_test = _to_lstm_input(X_test)

    # Step 2: build the model.
    model = _build_model(X_train.shape[2])

    # Step 3: train. Keras takes the validation_split fraction from the last
    # samples of the training data, before shuffling.
    model.fit(
        X_train,
        y_train,
        batch_size=4,
        epochs=4,
        validation_split=0.1)

    # A single-argument print(...) behaves the same under Python 2 and 3.
    print(model.predict(X_train))
    print(model.evaluate(X_train, y_train))
    # The numbers above are in-sample. Also report held-out performance on
    # the test split that was prepared above.
    test_score = model.evaluate(X_test, y_test)
    print('test [loss, accuracy]: %s' % (test_score,))
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号