net.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:speechless 作者: JuliusKunze 项目源码 文件源码
def test_and_predict_batch(self, labeled_spectrogram_batch: List[LabeledSpectrogram]) -> ExpectationsVsPredictions:
        """Evaluate one batch through the loss net and pair every prediction with its expectation.

        Args:
            labeled_spectrogram_batch: the labeled examples to evaluate. Each element's ``.label``
                is used as the expected value of the corresponding result.

        Returns:
            An ``ExpectationsVsPredictions`` built from one ``ExpectationVsPrediction`` per
            example, carrying the decoded prediction, the expected label and that example's loss.
        """
        feed_by_name, _unused_labels = self._inputs_for_loss_net(labeled_spectrogram_batch)

        # Build the feed in the order of self.loss_net.inputs. Only the part of each input's name
        # before the first ":" is used as the key (presumably TensorFlow-style "name:0" tensor
        # names -- confirm). The prediction-phase flag is appended as the final feed value.
        feed_values = [feed_by_name[tensor.name.split(":")[0]] for tensor in self.loss_net.inputs]
        feed_values.append(self.prediction_phase_flag)
        graphemes, batch_losses = self.get_predicted_graphemes_and_loss_batch(feed_values)

        # TensorFlow returns blank labels as -1; overwrite every negative entry (in place) with the
        # encoding's CTC blank index so the decoder sees a valid grapheme id.
        graphemes[graphemes < 0] = self.grapheme_encoding.ctc_blank

        # numpy.squeeze(..., axis=1) drops a length-1 second axis; both arrays are assumed to be
        # shaped (batch, 1) -- TODO confirm against the loss net's outputs.
        lengths_column = feed_by_name[Wav2Letter.InputNames.prediction_lengths]
        prediction_lengths = list(numpy.squeeze(lengths_column, axis=1))
        per_example_losses = list(numpy.squeeze(batch_losses, axis=1))

        # Repeated graphemes were already merged by TensorFlow, so merging is turned off here.
        decoded = self.grapheme_encoding.decode_grapheme_batch(
            graphemes, prediction_lengths, merge_repeated=False)

        # Labels are read lazily, in lockstep with the predictions and losses.
        expected_labels = (example.label for example in labeled_spectrogram_batch)
        pairs = [
            ExpectationVsPrediction(predicted=predicted, expected=expected, loss=loss)
            for predicted, expected, loss in zip(decoded, expected_labels, per_example_losses)
        ]
        return ExpectationsVsPredictions(pairs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号