predict.py source code

python

Project: golden_touch  Author: at553
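import numpy
from keras.layers import Dense, LSTM
from keras.models import Sequential
from sklearn.preprocessing import MinMaxScaler


def train_model(self):
    # scale every column of self.data into [0, 1]
    scaler = MinMaxScaler(feature_range=(0, 1))
    dataset = scaler.fit_transform(self.data)

    # split into train and test sets (95% / 5%); test is held out and not used here
    train_size = int(len(dataset) * 0.95)
    train, test = dataset[:train_size, :], dataset[train_size:, :]

    # create_dataset is defined elsewhere in the class
    look_back = 5
    trainX, trainY = self.create_dataset(train, look_back)

    # reshape input to be [samples, time steps, features]
    trainX = numpy.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))

    # create and fit the LSTM network
    model = Sequential()
    # input_shape replaces the Keras 1 input_dim argument
    model.add(LSTM(6, input_shape=(1, look_back)))
    model.add(Dense(1))
    model.compile(loss='mean_squared_error', optimizer='adam')
    # epochs replaces the Keras 1 nb_epoch argument
    model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=2)
    return model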
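
train_model calls self.create_dataset, which this excerpt does not show. The following is a minimal sketch of what such a helper might look like, assuming it follows the common sliding-window pattern: each input row holds look_back consecutive values from the first column, and the target is the value that comes right after. The function body, the toy series, and the names X, y and X3 are assumptions for illustration, not the project's actual code.

```python
import numpy


def create_dataset(dataset, look_back=1):
    """Hypothetical stand-in for self.create_dataset (not shown in the source).

    dataset: 2-D array of shape (n_samples, n_features); only column 0 is used.
    Returns (X, y) where X[i] is a window of look_back values and
    y[i] is the value immediately following that window.
    """
    data_x, data_y = [], []
    for i in range(len(dataset) - look_back):
        data_x.append(dataset[i:i + look_back, 0])
        data_y.append(dataset[i + look_back, 0])
    return numpy.array(data_x), numpy.array(data_y)


# Toy series 0..9 as a single-column array, standing in for the scaled data.
series = numpy.arange(10, dtype=float).reshape(-1, 1)
X, y = create_dataset(series, look_back=5)

# Same reshape as in train_model: [samples, time steps, features].
X3 = numpy.reshape(X, (X.shape[0], 1, X.shape[1]))
print(X.shape, y.shape, X3.shape)
```

With this layout each sample is a single time step with look_back features, which is why the LSTM input shape in train_model is (1, look_back). Treating the window as look_back time steps of one feature each would be an alternative, but that is not what the original reshape does.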