test_model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:mycroft 作者: wpm 项目源码 文件源码
def embedding_model_train_predict_evaluate(self, model):
        """Train ``model``, reload it from disk, and check predict/evaluate on the result.

        Shared helper for the test case; it is presumably invoked from the individual
        ``test_*`` methods with different model instances. It reads the fixtures
        ``self.texts``, ``self.labels`` and ``self.model_directory`` from the test case.

        Steps:
          1. Train ``model`` for 2 epochs (batch size 10, 10% validation split) and
             check it returns a ``History`` and writes its artifacts to
             ``self.model_directory``.
          2. Reload the model via ``load_embedding_model`` and check that the reloaded
             object has the same class as ``model``.
          3. Predict on ``self.texts`` and check the probability matrix shape/dtype
             and that every predicted label is one of the expected labels.
          4. Evaluate on ``self.texts``/``self.labels`` and validate the scores with
             ``self.is_loss_and_accuracy``.

        :param model: untrained embedding model exposing ``train(...)``; the reloaded
            copy is expected to expose ``predict`` and ``evaluate``.
        """
        # Train
        history = model.train(self.texts, self.labels, epochs=2, batch_size=10, validation_fraction=0.1,
                              model_directory=self.model_directory, verbose=0)
        self.assertIsInstance(history, History)
        # Training with a model_directory must persist every one of these artifacts.
        # The loop stops at the first missing file, as the original assertions did,
        # and the message names that file.
        for filename in ("model.hd5", "classifier.pk", "description.txt", "history.json"):
            path = os.path.join(self.model_directory, filename)
            self.assertTrue(os.path.exists(path), "missing training artifact: %s" % path)
        # Predict with the copy reloaded from disk, so the save/load round trip is covered.
        loaded_model = load_embedding_model(self.model_directory)
        self.assertIsInstance(loaded_model, model.__class__)
        n = len(self.texts)
        label_probabilities, predicted_labels = loaded_model.predict(self.texts)
        # One row per input text and two label columns.
        self.assertEqual((n, 2), label_probabilities.shape)
        self.assertEqual(numpy.dtype("float32"), label_probabilities.dtype)
        self.assertEqual(n, len(predicted_labels))
        # Subset check; on failure assertLessEqual reports both sets.
        self.assertLessEqual(set(predicted_labels), {"Joyce", "Kafka"})
        # Evaluate
        scores = loaded_model.evaluate(self.texts, self.labels)
        self.is_loss_and_accuracy(scores)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号