def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Compute the next decoder inputs, using model outputs where sampled.

    Delegates to the parent helper for ``finished``, the teacher-forced
    ``base_next_inputs`` and ``state``. Then, for every entry whose
    ``sample_ids`` flag is true, the base input is replaced by the model's
    ``outputs``:

    - The outputs are passed through ``self._next_input_layer`` when one is
      set.
    - They are concatenated with auxiliary inputs read at ``time + 1`` when
      ``self._auxiliary_input_tas`` is set.

    Args:
      time: Current decoding step.
      outputs: Model outputs for this step.
      state: Decoder state, forwarded to the parent helper.
      sample_ids: Per-entry sampling flags (true = use the model output).
        Cast to bool here. Presumably shape ``(batch,)`` — TODO confirm
        against the helper's ``sample()``.
      name: Optional name scope.

    Returns:
      A ``(finished, next_inputs, state)`` tuple.
    """
    with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                        [time, outputs, state, sample_ids]):
        (finished, base_next_inputs, state) = (
            super(ScheduledOutputTrainingHelper, self).next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name))
        # where / logical_not / reduce_any below all need a boolean mask;
        # the cast is a no-op when sample_ids is already bool.
        sample_ids = math_ops.cast(sample_ids, dtypes.bool)

        def maybe_sample():
            """Perform scheduled sampling."""

            def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
                """Concatenate outputs with auxiliary inputs, if they exist.

                Auxiliary inputs are read at ``time + 1``. When ``indices``
                is given, only those rows are kept, so they line up with an
                already-gathered ``outputs_``.
                """
                if self._auxiliary_input_tas is None:
                    return outputs_
                next_time = time + 1
                auxiliary_inputs = nest.map_structure(
                    lambda ta: ta.read(next_time), self._auxiliary_input_tas)
                if indices is not None:
                    # Gather leaf-by-leaf so a nested auxiliary structure is
                    # not auto-packed into a single tensor by gather_nd.
                    auxiliary_inputs = nest.map_structure(
                        lambda x: array_ops.gather_nd(x, indices),
                        auxiliary_inputs)
                return nest.map_structure(
                    lambda x, y: array_ops.concat((x, y), -1),
                    outputs_, auxiliary_inputs)

            if self._next_input_layer is None:
                # No projection: take the model output (plus auxiliary
                # inputs) where sampled, the base input elsewhere.
                return array_ops.where(
                    sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
                    base_next_inputs)

            # With a projection layer, apply it only to the sampled rows:
            # gather them out, transform, and scatter both groups back.
            where_sampling = math_ops.cast(
                array_ops.where(sample_ids), dtypes.int32)
            where_not_sampling = math_ops.cast(
                array_ops.where(math_ops.logical_not(sample_ids)),
                dtypes.int32)
            outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
            inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                      where_not_sampling)
            sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
                self._next_input_layer(outputs_sampling), where_sampling)

            # NOTE(review): assumes the projected (+auxiliary) rows have the
            # same trailing shape as base_next_inputs, as scatter_nd requires.
            base_shape = array_ops.shape(base_next_inputs)
            # The two index sets are complementary. Each row is written by
            # exactly one scatter, and the other scatter leaves zeros there,
            # so summing the two merges them.
            return (array_ops.scatter_nd(indices=where_sampling,
                                         updates=sampled_next_inputs,
                                         shape=base_shape)
                    + array_ops.scatter_nd(indices=where_not_sampling,
                                           updates=inputs_not_sampling,
                                           shape=base_shape))

        all_finished = math_ops.reduce_all(finished)
        # With nothing sampled, maybe_sample would just reproduce
        # base_next_inputs. Skip it to avoid the gather/scatter work and
        # running the projection layer on an empty batch.
        no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids))
        next_inputs = control_flow_ops.cond(
            math_ops.logical_or(all_finished, no_samples),
            lambda: base_next_inputs, maybe_sample)
        return (finished, next_inputs, state)