def reconstruction_loss(self, x_input, x_target, x_length, z=None):
"""Reconstruction loss calculation.
Args:
x_input: Batch of decoder input sequences for teacher forcing, sized
`[batch_size, max(x_length), output_depth]`.
x_target: Batch of expected output sequences to compute loss against,
sized `[batch_size, max(x_length), output_depth]`.
x_length: Length of input/output sequences, sized `[batch_size]`.
z: (Optional) Latent vectors. Required if model is conditional. Sized
`[n, z_size]`.
Returns:
r_loss: The reconstruction loss for each sequence in the batch.
metric_map: Map from metric name to tf.metrics return values for logging.
truths: Ground truth labels, sized
"""
  batch_size = x_input.shape[0].value
  has_z = z is not None
  z = tf.zeros([batch_size, 0]) if z is None else z
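  # Tile z across the time dimension so the latent vector can be presented to
  # the decoder at every step (either concatenated to the inputs or passed as
  # auxiliary inputs below).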
  repeated_z = tf.tile(
      tf.expand_dims(z, axis=1), [1, tf.shape(x_input)[1], 1])
  sampling_probability_static = tensor_util.constant_value(
      self._sampling_probability)
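  # `constant_value` resolves the sampling probability at graph-construction
  # time when possible, so the cheaper teacher-forcing helper can be selected
  # statically.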
  if sampling_probability_static == 0.0:
    # Use teacher forcing.
    x_input = tf.concat([x_input, repeated_z], axis=2)
    helper = tf.contrib.seq2seq.TrainingHelper(x_input, x_length)
  else:
    # Use scheduled sampling.
    helper = tf.contrib.seq2seq.ScheduledOutputTrainingHelper(
        inputs=x_input,
        sequence_length=x_length,
        auxiliary_inputs=repeated_z if has_z else None,
        sampling_probability=self._sampling_probability,
        next_inputs_fn=self._sample)
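  # Run the decoder with the chosen helper; `rnn_output` holds the per-step
  # outputs that the reconstruction loss is computed against.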
  decoder_outputs = self._decode(batch_size, helper=helper, z=z)
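  # Flatten away padding so the loss is only computed over valid time steps.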
  flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
  flat_rnn_output = flatten_maybe_padded_sequences(
      decoder_outputs.rnn_output, x_length)
  r_loss, metric_map, truths, predictions = self._flat_reconstruction_loss(
      flat_x_target, flat_rnn_output)

  # Sum loss over sequences.
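  # `cum_x_len[i]:cum_x_len[i + 1]` spans the i-th sequence inside the
  # flattened, padding-free loss tensor.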
  cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)
  r_losses = []
  for i in range(batch_size):
    b, e = cum_x_len[i], cum_x_len[i + 1]
    r_losses.append(tf.reduce_sum(r_loss[b:e]))
  r_loss = tf.stack(r_losses)
  return r_loss, metric_map, truths, predictions
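
# A minimal, self-contained sketch (not part of the original source) of the
# per-sequence summation step above, assuming TensorFlow 1.x to match the
# `tf.contrib` usage in the method. Cumulative lengths index the flattened,
# padding-free loss tensor so each sequence's step losses are summed
# separately. The tensor values here are illustrative only.
import tensorflow as tf

flat_loss = tf.constant([1., 2., 3., 4., 5., 6.])  # losses for valid steps only
x_length = tf.constant([2, 4])                     # two sequences: 2 and 4 steps
cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)  # [0, 2, 6]
per_sequence_loss = tf.stack(
    [tf.reduce_sum(flat_loss[cum_x_len[i]:cum_x_len[i + 1]]) for i in range(2)])
with tf.Session() as sess:
  print(sess.run(per_sequence_loss))  # [ 3. 18.]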