test_tasks.py source code

python

Project: deep-coref  Author: clarkkev
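from __future__ import print_function
import unittest

# Pre-2.0 Keras API: TimeDistributedDense and nb_epoch no longer exist in Keras 2
from keras.models import Sequential
from keras.layers.core import TimeDistributedDense
from keras.utils.test_utils import get_test_data


class TestTasks(unittest.TestCase):  # enclosing class name assumed; the snippet shows only this method
    def test_seq_to_seq(self):
        print('sequence to sequence data:')
        # Random regression data: inputs and targets are both 3 timesteps x 5 features
        (X_train, y_train), (X_test, y_test) = get_test_data(nb_train=1000, nb_test=200, input_shape=(3, 5), output_shape=(3, 5),
                                                             classification=False)
        print('X_train:', X_train.shape)
        print('X_test:', X_test.shape)
        print('y_train:', y_train.shape)
        print('y_test:', y_test.shape)

        # A single Dense layer applied independently at every timestep
        model = Sequential()
        model.add(TimeDistributedDense(y_train.shape[-1], input_shape=(None, X_train.shape[-1])))
        model.compile(loss='hinge', optimizer='rmsprop')
        history = model.fit(X_train, y_train, nb_epoch=12, batch_size=16, validation_data=(X_test, y_test), verbose=2)
        # The test passes if the final validation loss falls below 0.8
        self.assertTrue(history.history['val_loss'][-1] < 0.8)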
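TimeDistributedDense applies one Dense layer, with a single shared weight matrix and bias, to each timestep of a (samples, timesteps, features) input; that is why `input_shape=(None, X_train.shape[-1])` can leave the sequence length unspecified. The NumPy sketch below illustrates that forward pass, assuming the layer's default linear activation, together with the hinge loss the test compiles with (written here as pre-2.0 Keras defines it: the mean over the last axis of max(1 - y_true * y_pred, 0)). `time_distributed_dense`, `hinge`, and the random arrays are hypothetical stand-ins for illustration, not Keras code.

```python
import numpy as np


def time_distributed_dense(X, W, b):
    """Apply one shared affine map (linear activation) at every timestep.

    X: (samples, timesteps, in_features)
    W: (in_features, out_features)
    b: (out_features,)
    returns: (samples, timesteps, out_features)
    """
    # Contract only the feature axis; samples and timesteps pass through unchanged.
    return np.einsum('ntf,fo->nto', X, W) + b


def hinge(y_true, y_pred):
    """Hinge loss, averaged over the last (feature) axis."""
    return np.mean(np.maximum(1.0 - y_true * y_pred, 0.0), axis=-1)


rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3, 5))   # 4 sequences, 3 timesteps, 5 features, as in the test
y = rng.normal(size=(4, 3, 5))   # hypothetical regression targets of the same shape
W = rng.normal(size=(5, 5))
b = np.zeros(5)

Y = time_distributed_dense(X, W, b)
print(Y.shape)  # (4, 3, 5)

# Equivalent explicit loop: the same W and b are reused at each timestep.
loop = np.stack([X[:, t, :] @ W + b for t in range(X.shape[1])], axis=1)
print(np.allclose(Y, loop))  # True

# One scalar per batch: the per-timestep hinge values averaged over samples and timesteps.
print(hinge(y, Y).mean())
```

Because the weights do not depend on the timestep index, the same layer accepts sequences of any length. In Keras 2 and later, TimeDistributedDense was removed in favour of wrapping a Dense layer in `TimeDistributed`, and `fit`'s `nb_epoch` argument became `epochs`.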