def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Compute the next decoder inputs, applying scheduled sampling.

    The parent helper's ``next_inputs`` supplies the base next inputs.
    Every position whose ``sample_ids`` entry is greater than -1 has its
    input replaced by ``self._embedding_fn`` applied to that sampled id;
    positions with ``sample_ids <= -1`` keep the base input unchanged.

    Args:
        time: Current decoding step; forwarded to the parent helper.
        outputs: Decoder outputs for this step; forwarded to the parent.
        state: Decoder state; forwarded to the parent.
        sample_ids: Tensor of sampled ids, where any value <= -1 marks a
            position that was not sampled. Its shape is assumed to be a
            leading prefix of the base inputs' shape (presumably
            ``[batch_size]``) -- TODO confirm against ``sample()``.
        name: Optional name scope for the ops created here.

    Returns:
        A ``(finished, next_inputs, state)`` tuple; ``finished`` and
        ``state`` are the values returned by the parent helper.
    """
    with tf.name_scope(name, "ScheduledEmbeddingTrainingHelperNextInputs",
                       [time, outputs, state, sample_ids]):
        (finished, base_next_inputs, state) = (
            super(ScheduledEmbeddingTrainingHelper, self).next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name))

        def maybe_sample():
            """Merge sampled embeddings and base inputs by ``sample_ids``."""
            # Coordinate rows of the sampled / not-sampled positions. The two
            # conditions are complementary, so together they cover every
            # element of ``sample_ids`` exactly once. tf.where yields int64
            # coordinates; they are cast to int32 here.
            where_sampling = tf.cast(tf.where(sample_ids > -1), tf.int32)
            where_not_sampling = tf.cast(
                tf.where(sample_ids <= -1), tf.int32)
            # gather_nd consumes tf.where's coordinate rows directly, so no
            # flattening to 1-D indices is needed (and higher-rank
            # ``sample_ids`` are indexed correctly).
            sample_ids_sampling = tf.gather_nd(sample_ids, where_sampling)
            inputs_not_sampling = tf.gather_nd(
                base_next_inputs, where_not_sampling)
            sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
            base_shape = tf.shape(base_next_inputs)
            # Each scatter writes its own positions and zeros elsewhere; the
            # index sets are disjoint, so the sum is the merged tensor.
            return (tf.scatter_nd(indices=where_sampling,
                                  updates=sampled_next_inputs,
                                  shape=base_shape)
                    + tf.scatter_nd(indices=where_not_sampling,
                                    updates=inputs_not_sampling,
                                    shape=base_shape))

        # Once every sequence has finished, skip the sampling path and pass
        # the base inputs through unchanged.
        all_finished = tf.reduce_all(finished)
        next_inputs = tf.cond(
            all_finished, lambda: base_next_inputs, maybe_sample)
        return (finished, next_inputs, state)
评论列表
文章目录