def loss_net(self) -> Model:
    """Return the network that yields a loss from input spectrograms and labels; used for training.

    The returned model has four inputs, in this order:

    - the input batch (``self._input_batch_input``),
    - the label batch (int32, variable length),
    - the prediction lengths (``self._prediction_lengths_input``),
    - the label lengths (int64, shape ``(1,)``).

    Its single output is produced by a ``Lambda`` layer placed atop
    ``self.predictive_net``. That layer wraps ``Wav2Letter._asg_lambda`` when
    ``self.use_asg`` is true and ``Wav2Letter._ctc_lambda`` otherwise. The model is
    compiled with ``self.optimizer`` and a pass-through loss, because the loss is
    already computed by the last layer.

    Returns:
        The compiled Keras ``Model`` whose output is the loss layer's result
        (declared ``output_shape=(1,)``).
    """
    input_batch = self._input_batch_input
    label_batch = Input(name=Wav2Letter.InputNames.label_batch, shape=(None,), dtype='int32')
    label_lengths = Input(name=Wav2Letter.InputNames.label_lengths, shape=(1,), dtype='int64')

    # Keras does not support loss functions with extra parameters, so the loss is
    # computed inside a custom Lambda layer; any extra parameters are bound to it
    # through the layer's ``arguments``.
    if self.use_asg:
        loss_function = Wav2Letter._asg_lambda
        loss_layer_name = 'asg_loss'
        # Only the ASG loss needs the transition/initial probabilities, so the
        # backend variables are allocated on this path only.
        loss_arguments = {
            "transition_probabilities": backend.variable(value=self.asg_transition_probabilities,
                                                         name="asg_transition_probabilities"),
            "initial_probabilities": backend.variable(value=self.asg_initial_probabilities,
                                                      name="asg_initial_probabilities"),
        }
    else:
        loss_function = Wav2Letter._ctc_lambda
        loss_layer_name = 'ctc_loss'
        loss_arguments = None

    loss_layer = Lambda(loss_function,
                        name=loss_layer_name,
                        output_shape=(1,),
                        arguments=loss_arguments)

    # The loss layer sits atop the predictive network. It also receives the label
    # batch and the prediction/label sequence lengths as inputs.
    loss_inputs = [input_batch, label_batch, self._prediction_lengths_input, label_lengths]
    loss = loss_layer(
        [self.predictive_net(input_batch), label_batch, self._prediction_lengths_input, label_lengths])
    loss_net = Model(inputs=loss_inputs, outputs=[loss])

    # The loss is already the output of the last layer, so the compiled loss simply
    # passes it through. Keras still requires (dummy) target labels to satisfy its API.
    loss_net.compile(loss=lambda dummy_labels, precomputed_loss: precomputed_loss,
                     optimizer=self.optimizer)
    return loss_net
评论列表
文章目录