def test_and_predict_batch(self, labeled_spectrogram_batch: List[LabeledSpectrogram]) -> ExpectationsVsPredictions:
    """Evaluate a labeled batch and pair each decoded prediction with its expected label.

    Builds the loss-network feed from ``labeled_spectrogram_batch``, runs
    ``get_predicted_graphemes_and_loss_batch`` with the prediction-phase flag
    appended, decodes the predicted grapheme ids and collects one
    ``ExpectationVsPrediction`` per batch element.

    Args:
        labeled_spectrogram_batch: examples to evaluate; each element's ``label``
            attribute is used as the expected value.

    Returns:
        An ``ExpectationsVsPredictions`` holding, in batch order, the predicted
        decoding, the expected label and the per-example loss.
    """
    input_by_name, _dummy_labels = self._inputs_for_loss_net(labeled_spectrogram_batch)

    # The feed dict is keyed by the part of each loss-net input name before ':'
    # (presumably TensorFlow-style "<name>:<index>" tensor names — confirm).
    feed_values = [input_by_name[tensor.name.split(":")[0]] for tensor in self.loss_net.inputs]
    # The prediction-phase flag is passed as the final value after all loss-net inputs.
    feed_values.append(self.prediction_phase_flag)
    predicted_graphemes, loss_batch = self.get_predicted_graphemes_and_loss_batch(feed_values)

    # TensorFlow returns blank labels as -1; remap them to this encoding's CTC blank id (in place).
    is_blank = predicted_graphemes < 0
    predicted_graphemes[is_blank] = self.grapheme_encoding.ctc_blank

    # numpy.squeeze(axis=1) requires axis 1 to have size 1, i.e. one value per example.
    prediction_lengths = list(numpy.squeeze(input_by_name[Wav2Letter.InputNames.prediction_lengths], axis=1))
    losses = list(numpy.squeeze(loss_batch, axis=1))

    # Repeated graphemes were already merged on the TensorFlow side, so the decoder must not merge again.
    predictions = self.grapheme_encoding.decode_grapheme_batch(
        predicted_graphemes, prediction_lengths, merge_repeated=False)

    expected_labels = (example.label for example in labeled_spectrogram_batch)
    return ExpectationsVsPredictions([
        ExpectationVsPrediction(predicted=predicted, expected=expected, loss=loss)
        for predicted, expected, loss in zip(predictions, expected_labels, losses)
    ])
评论列表
文章目录