main.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:speechless 作者: JuliusKunze 项目源码 文件源码
def validate_to_csv(model_name: str, last_epoch: int,
                        configuration: Configuration = Configuration.german(),
                        step_count=10, first_epoch: int = 0,
                        csv_directory: Path = configuration.default_data_directories.test_results_directory) -> List[
        Tuple[int, ExpectationsVsPredictionsInGroupedBatches]]:
    """Test a model at several saved epochs and write one CSV row per epoch.

    Up to ``step_count`` epochs, evenly spaced from ``first_epoch`` to
    ``last_epoch`` (both ends included), are selected. For each of them the
    weights stored under ``configuration.directories.nets_base_directory /
    model_name`` are loaded into the model, which is then evaluated with
    ``configuration.test_model_grouped_by_loaded_corpus_name``.

    Each CSV row holds: epoch, average loss, average letter error rate,
    average word error rate, average letter error count, average word error
    count. The file is written to ``csv_directory`` and is named
    ``<model_name>-incl-trans.csv``.

    Args:
        model_name: Name of the model; also the name of its weights directory.
        last_epoch: Last epoch to test; the model is initially loaded at
            this epoch.
        configuration: Configuration used to load and test the model.
        step_count: Maximum number of epochs to test. A value of 1 tests
            only ``last_epoch``.
        first_epoch: First epoch to test.
        csv_directory: Directory the CSV file is written to.
            NOTE(review): this default is evaluated when the function is
            defined, so ``configuration`` there resolves to a name from the
            enclosing scope, not to the ``configuration`` parameter --
            confirm this is intended.

    Returns:
        The ``(epoch, result)`` pairs in the order they were tested.
    """
    import csv

    if step_count == 1:
        # A single step has no spacing to compute; the original formula
        # divided by zero here.
        epochs = [last_epoch]
    else:
        # Exact integer arithmetic: the float version could truncate the final
        # step to last_epoch - 1 because of rounding (1 / 49 * 49 < 1).
        epoch_span = last_epoch - first_epoch
        epochs = distinct([first_epoch + (index * epoch_span) // (step_count - 1)
                           for index in range(step_count)])
    log("Testing model {} on epochs {}.".format(model_name, epochs))

    model = configuration.load_model(model_name, last_epoch,
                                     allowed_characters_for_loaded_model=configuration.allowed_characters,
                                     use_kenlm=True,
                                     language_model_name_extension="-incl-trans")

    def get_result(epoch: int) -> ExpectationsVsPredictionsInGroupedBatches:
        """Load the weights saved for ``epoch`` into the model and test it."""
        log("Testing epoch {}.".format(epoch))

        model.load_weights(
            allowed_characters_for_loaded_model=configuration.allowed_characters,
            load_model_from_directory=configuration.directories.nets_base_directory / model_name,
            load_epoch=epoch)

        return configuration.test_model_grouped_by_loaded_corpus_name(model)

    results_with_epochs = []

    csv_file = csv_directory / "{}.csv".format(model_name + "-incl-trans")
    # newline='' lets the csv writer control line endings itself, as the csv
    # module documentation requires for files passed to csv.writer.
    with csv_file.open('w', encoding='utf8', newline='') as opened_csv:
        writer = csv.writer(opened_csv, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)

        for epoch in epochs:
            result = get_result(epoch)
            writer.writerow((epoch, result.average_loss, result.average_letter_error_rate,
                             result.average_word_error_rate, result.average_letter_error_count,
                             result.average_word_error_count))
            # Collect the result so the declared return value is actually
            # populated; previously the list was returned empty.
            results_with_epochs.append((epoch, result))

    return results_with_epochs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号