test_tasks.py source code

python

Project: deep-coref  Author: clarkkev
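def test_temporal_clf(self):
    # Written against the legacy Keras 0.x API: nb_epoch and show_accuracy
    # were renamed or removed in later Keras releases.
    print('temporal classification data:')
    # Synthetic two-class data: each sample is a sequence of 3 timesteps with 5 features.
    (X_train, y_train), (X_test, y_test) = get_test_data(nb_train=1000, nb_test=200, input_shape=(3, 5),
                                                         classification=True, nb_class=2)
    print('X_train:', X_train.shape)
    print('X_test:', X_test.shape)
    print('y_train:', y_train.shape)
    print('y_test:', y_test.shape)

    # One-hot encode the integer labels to match the softmax output.
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    # A single GRU layer with one unit per class; its final state feeds the softmax.
    # input_shape=(None, n_features) leaves the sequence length unspecified.
    model = Sequential()
    model.add(GRU(y_train.shape[-1], input_shape=(None, X_train.shape[-1])))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adadelta')
    history = model.fit(X_train, y_train, nb_epoch=12, batch_size=16,
                        validation_data=(X_test, y_test), show_accuracy=True, verbose=2)
    # The test passes only if validation accuracy after the last epoch exceeds 90%.
    self.assertTrue(history.history['val_acc'][-1] > 0.9)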
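
The to_categorical calls in the test above turn integer class labels into one-hot vectors, so the targets have the same shape as the softmax output that categorical_crossentropy expects. A minimal pure-Python sketch of that encoding follows, assuming the labels are integers starting at 0. The one_hot helper is hypothetical and only illustrates the idea; it is not the Keras implementation.

```python
def one_hot(labels, nb_class=None):
    """Hypothetical helper: map integer class labels to one-hot rows."""
    labels = [int(label) for label in labels]
    if nb_class is None:
        # Assumption: classes are numbered 0..max(labels).
        nb_class = max(labels) + 1
    rows = []
    for label in labels:
        row = [0.0] * nb_class
        row[label] = 1.0  # mark the position of the true class
        rows.append(row)
    return rows


if __name__ == "__main__":
    # Four samples from a two-class problem.
    print(one_hot([0, 1, 1, 0]))
```

With two classes, each label becomes a row of length 2, which is why the test sizes the GRU layer with y_train.shape[-1] units.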