def next_inp(self, time, output):
    """Returns the next input.

    At timestep `t + 1` the desired decoder input is the output from the
    previous step `t`. Consequently, at timestep `t` the next input is the
    desired output for that very same timestep if decoder inputs have been
    provided (i.e. `self._inputs_ta` is not `None`); otherwise it is just the
    current output.

    When decoder inputs have been provided but `time` is not smaller than
    `self._inputs_ta.size()`, the value returned by `self.zero_output()` is
    used instead of a stored input.

    Arguments:
      time: an `int` or scalar `Tensor` representing the current timestep.
      output: a `2D Tensor` of shape `[batch_size, output_size]` representing
        the current output.

    Returns:
      The selected tensor (stored decoder input, zero output, or the current
      `output`) passed through `ops.fit(..., self._inp_size)` -- presumably to
      adapt it to the expected input size; confirm against `ops.fit`.
    """
    # Explicit None check: the inputs container is optional, and relying on
    # its truthiness is implicit (and raises for Tensor-like objects).
    if self._inputs_ta is not None:
        # Read the provided decoder input for this step while inputs remain;
        # past the end of the stored inputs, fall back to the zero output.
        # `tf.cond` takes zero-argument callables, so the bound method is
        # passed directly instead of being wrapped in a lambda.
        output = tf.cond(
            time < self._inputs_ta.size(),
            lambda: self._inputs_ta.read(time),
            self.zero_output)
    return ops.fit(output, self._inp_size)
评论列表
文章目录