base.py source code

python

Project: motion-classification    Author: matthiasplappert
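import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def fit(self, X, y):
    # Validate explicitly (resolves the original TODO): asserts are stripped under `python -O`.
    if not isinstance(X, list):
        raise TypeError('X must be a list of (timesteps, n_features) arrays')
    y = np.asarray(y)  # expected shape: (n_samples, n_labels), one-hot/indicator encoded
    if len(y) == 0 or len(X) != len(y):
        raise ValueError('X and y must be non-empty and of equal length')

    # Pad to a common length. dtype='float32' is needed for real-valued
    # features: the default dtype 'int32' would truncate them.
    X = keras.utils.pad_sequences(X, dtype='float32')
    print(X.shape, y.shape)

    n_features = X.shape[2]
    self.n_labels_ = y.shape[1]
    print(n_features, self.n_labels_)

    # Keras 0.x signatures such as GRU(input_dim, output_dim) and Dense(in, out)
    # become output-size-only layers plus an explicit Input.
    model = keras.Sequential([
        keras.Input(shape=(X.shape[1], n_features)),
        layers.GRU(128),
        layers.Dropout(0.1),
        layers.BatchNormalization(),
        layers.Dense(self.n_labels_),
        # Kept from the original; softmax is the usual pairing with
        # categorical_crossentropy for single-label targets.
        layers.Activation('sigmoid'),
    ])

    # Keras 0.x `decay` (lr / (1 + decay * iterations)) expressed as a schedule.
    # With momentum=0 the Nesterov flag has no effect.
    lr = keras.optimizers.schedules.InverseTimeDecay(0.005, decay_steps=1, decay_rate=1e-6)
    sgd = keras.optimizers.SGD(learning_rate=lr, momentum=0., nesterov=True)
    # `show_accuracy=True` and `class_mode` were removed; accuracy is now a compile metric.
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

    model.fit(X, y, batch_size=self.n_batch_size, epochs=self.n_epochs)
    self.model_ = model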
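The padding step is the least obvious part of `fit`. Below is a minimal NumPy sketch of what pre-padding does to a list of variable-length `(timesteps, n_features)` sequences, assuming the Keras defaults `padding='pre'` and a fill value of 0.0. `pad_sequences_pre` is a hypothetical helper for illustration only, not part of Keras or the motion-classification project, and it skips truncation (`maxlen`) entirely.

```python
import numpy as np


def pad_sequences_pre(sequences, value=0.0, dtype=np.float32):
    """Hypothetical helper: stack variable-length (timesteps, n_features)
    arrays into one (n_samples, max_len, n_features) array, filling the
    *start* of shorter sequences with `value` ('pre' padding)."""
    arrays = [np.asarray(s, dtype=dtype) for s in sequences]
    max_len = max(a.shape[0] for a in arrays)
    n_features = arrays[0].shape[1]
    out = np.full((len(arrays), max_len, n_features), value, dtype=dtype)
    for i, a in enumerate(arrays):
        # Right-align each sequence so the padding sits in front of it.
        out[i, max_len - a.shape[0]:] = a
    return out


# Two recordings with 2 features each: one of length 2, one of length 1.
X = [np.array([[0.5, 1.5], [2.5, 3.5]]), np.array([[4.25, 5.75]])]
X_pad = pad_sequences_pre(X)
print(X_pad.shape)  # (2, 2, 2)
```

Keeping a float dtype matters for the same reason as in `fit`: casting these features to an integer dtype would turn 0.5 into 0.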