def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Compute the next decoder inputs with scheduled output sampling.

    Starts from ``base_next_inputs`` as returned by the parent helper's
    ``next_inputs``. For every entry flagged in ``sample_ids`` it substitutes
    the decoder's own ``outputs`` instead. Those outputs are optionally
    passed through ``self._next_input_layer`` and, when auxiliary inputs
    exist, concatenated with the auxiliary inputs read for step
    ``time + 1``.

    Args:
        time: Current decoding step; auxiliary inputs are read at
            ``time + 1``.
        outputs: Decoder outputs for this step.
        state: Decoder state; forwarded to the parent helper.
        sample_ids: Flags selecting which entries feed back their own
            outputs. Cast to ``tf.bool`` here, so boolean or 0/1 numeric
            flags are both accepted.
        name: Optional name scope.

    Returns:
        A ``(finished, next_inputs, state)`` tuple. ``finished`` and
        ``state`` are those returned by the parent helper.
    """
    with tf.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                       [time, outputs, state, sample_ids]):
        (finished, base_next_inputs, state) = (
            super(ScheduledOutputTrainingHelper, self).next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name))
        # tf.where / tf.logical_not below require a boolean mask; a sampler
        # may return 0/1 integers instead. This cast is a no-op for bools.
        sample_ids = tf.cast(sample_ids, tf.bool)

        def maybe_sample():
            """Perform scheduled sampling."""

            def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
                """Concatenate outputs with auxiliary inputs, if they exist.

                When ``indices`` is given, only the auxiliary-input rows
                selected by those ``gather_nd`` indices are used, so they
                line up with an ``outputs_`` gathered with the same indices.
                """
                if self._auxiliary_input_tas is None:
                    return outputs_
                # The inputs being built are for the next step, so read the
                # auxiliary inputs at time + 1.
                next_time = time + 1
                auxiliary_inputs = nest.map_structure(
                    lambda ta: ta.read(next_time), self._auxiliary_input_tas)
                if indices is not None:
                    # Gather per leaf, matching the per-leaf read above and
                    # concat below (a nested structure must not be packed
                    # into a single tensor before gathering).
                    auxiliary_inputs = nest.map_structure(
                        lambda aux: tf.gather_nd(aux, indices),
                        auxiliary_inputs)
                return nest.map_structure(
                    lambda x, y: tf.concat((x, y), -1),
                    outputs_, auxiliary_inputs)

            if self._next_input_layer is None:
                # No projection: select between fed-back outputs and the
                # base inputs using the sample_ids mask.
                # NOTE(review): assumes sample_ids is rank-1 over the batch,
                # where TF1 tf.where picks whole rows -- confirm for TF2.
                return tf.where(
                    sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
                    base_next_inputs)

            # With a projection layer, run only the sampled entries through
            # it, then scatter both partitions back to the full shape.
            where_sampling = tf.cast(tf.where(sample_ids), tf.int32)
            where_not_sampling = tf.cast(
                tf.where(tf.logical_not(sample_ids)), tf.int32)
            outputs_sampling = tf.gather_nd(outputs, where_sampling)
            inputs_not_sampling = tf.gather_nd(base_next_inputs,
                                               where_not_sampling)
            sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
                self._next_input_layer(outputs_sampling), where_sampling)

            base_shape = tf.shape(base_next_inputs)
            # The two index sets are disjoint, so summing the scatters
            # reassembles one tensor with every position filled once.
            return (tf.scatter_nd(indices=where_sampling,
                                  updates=sampled_next_inputs,
                                  shape=base_shape)
                    + tf.scatter_nd(indices=where_not_sampling,
                                    updates=inputs_not_sampling,
                                    shape=base_shape))

        all_finished = tf.reduce_all(finished)
        # If every sequence is finished or no entry was selected for
        # sampling, the result is base_next_inputs; skip the sampling
        # sub-graph entirely in that case.
        no_samples = tf.logical_not(tf.reduce_any(sample_ids))
        next_inputs = tf.cond(
            tf.logical_or(all_finished, no_samples),
            lambda: base_next_inputs, maybe_sample)
        return (finished, next_inputs, state)
评论列表
文章目录