lstm.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:mars_express 作者: wsteitz 项目源码 文件源码
def fit(self, x, y):
    """Build and train a stacked GRU sequence model that maps ``x`` to ``y``.

    The network has three GRU layers of 1024, 1024 and 512 units, each
    followed by a tanh ``Activation``. A per-timestep ``Dense`` head with one
    output per column of ``y`` sits on top. It is trained with MSE loss and
    the Nadam optimizer (lr=0.002).

    Args:
        x: Input features. It is accessed via ``.shape[1]`` and ``.iloc``, so
            it is presumably a 2-D pandas DataFrame with rows ordered in
            time. TODO confirm.
        y: Targets, presumably aligned row-for-row with ``x``, under the same
            assumption as ``x``.

    Returns:
        ``self``, with ``self.model`` set to the trained Keras model and
        ``self.x_train`` set to ``x``.
    """
    input_dim = x.shape[1]
    output_dim = y.shape[1]
    self.x_train = x

    # Skip the leading rows so the remaining length is an exact multiple of
    # batch_size * sequence_length. Presumably this lets sliding_window's
    # output split evenly into fixed-size batches, which the stateful GRU
    # below requires. TODO confirm against sliding_window.
    start = len(x) % (self.batch_size * self.sequence_length)

    x_seq = self.sliding_window(x.iloc[start:])
    y_seq = self.sliding_window(y.iloc[start:])

    model = Sequential()
    # batch_input_shape pins the batch size, which a stateful RNN layer needs.
    # NOTE(review): only this first GRU is stateful. GRU layers 2 and 3 start
    # from fresh state every batch. Confirm this mix is intended.
    model.add(GRU(1024, batch_input_shape=(self.batch_size, self.sequence_length, input_dim), return_sequences=True, stateful=True))
    model.add(Activation("tanh"))
    model.add(GRU(1024, return_sequences=True))
    model.add(Activation("tanh"))
    model.add(GRU(512, return_sequences=True))
    model.add(Activation("tanh"))
    # model.add(Dropout(0.5))
    # One output vector per time step (return_sequences=True above).
    model.add(TimeDistributed(Dense(output_dim)))
    # Identity activation; kept so the layer stack (and any saved weights)
    # stays the same.
    model.add(Activation("linear"))

    # NOTE(review): `lr=` and `nb_epoch=` are older Keras argument names.
    # Newer Keras releases spell them `learning_rate=` and `epochs=`.
    optimizer = keras.optimizers.Nadam(lr=0.002)
    model.compile(loss='mse', optimizer=optimizer)

    # shuffle=False keeps batches in their original order, so the stateful
    # layer's carried-over state matches the next consecutive batch.
    # NOTE(review): nothing calls model.reset_states() between epochs, so
    # state from the end of one epoch carries into the start of the next.
    # Confirm this is intended.
    model.fit(x_seq, y_seq, batch_size=self.batch_size, verbose=1, nb_epoch=self.n_epochs, shuffle=False)
    self.model = model
    return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号