def tf_process(self, tensor):
    # Sequencing assumes exactly one state per step, i.e. a batch size of 1.
    assertion = tf.assert_equal(x=tf.shape(input=tensor)[0], y=1)

    # Circular buffer holding the most recent `self.length` states.
    states_buffer = tf.get_variable(
        name='states-buffer',
        shape=((self.length,) + util.shape(tensor)[1:]),
        dtype=tensor.dtype,
        trainable=False
    )

    # Current write position in the buffer; -1 marks an uninitialized buffer.
    index = tf.get_variable(
        name='index',
        dtype=util.tf_dtype('int'),
        initializer=-1,
        trainable=False
    )

    # The assertion only executes if the rest of the graph depends on it.
    with tf.control_dependencies(control_inputs=(assertion,)):
        # On the first call, fill the whole buffer with copies of the initial
        # state; afterwards, overwrite only the slot at the current index.
        assignment = tf.cond(
            pred=tf.equal(x=index, y=-1),
            true_fn=(lambda: tf.assign(
                ref=states_buffer,
                value=tf.tile(
                    input=tensor,
                    multiples=((self.length,) + tuple(1 for _ in range(util.rank(tensor) - 1)))
                )
            )),
            false_fn=(lambda: tf.assign(ref=states_buffer[index], value=tensor[0]))
        )

    with tf.control_dependencies(control_inputs=(assignment,)):
        # Read the `self.length` buffered states relative to the write
        # position, then advance the index, wrapping around the buffer.
        previous_states = [states_buffer[(index - n - 1) % self.length] for n in range(self.length)]
        assignment = tf.assign(ref=index, value=((tf.maximum(x=index, y=0) + 1) % self.length))

    with tf.control_dependencies(control_inputs=(assignment,)):
        # Concatenate the buffered states along the last axis and restore the
        # batch dimension.
        return tf.expand_dims(input=tf.concat(values=previous_states, axis=-1), axis=0)
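
For context, here is a minimal standalone sketch of the same circular-buffer sequencing, runnable under TensorFlow 1.x (the API style used above). The names demo-buffer and demo-index, the buffer length of 3, and the two-dimensional dummy states are illustrative, not part of the original; the extra control dependency that forces the buffer reads to complete before the index update is also an addition, since graph mode does not otherwise guarantee that ordering.

import numpy as np
import tensorflow as tf  # assumes TensorFlow 1.x, matching the API above

length, feature_dim = 3, 2  # illustrative values
state = tf.placeholder(dtype=tf.float32, shape=(1, feature_dim))

# Illustrative stand-ins for the variables created inside tf_process.
states_buffer = tf.get_variable(
    name='demo-buffer', shape=(length, feature_dim), dtype=tf.float32, trainable=False
)
index = tf.get_variable(name='demo-index', dtype=tf.int32, initializer=-1, trainable=False)

# First call fills the whole buffer; later calls overwrite a single slot.
assignment = tf.cond(
    pred=tf.equal(x=index, y=-1),
    true_fn=(lambda: tf.assign(ref=states_buffer, value=tf.tile(input=state, multiples=(length, 1)))),
    false_fn=(lambda: tf.assign(ref=states_buffer[index], value=state[0]))
)
with tf.control_dependencies(control_inputs=(assignment,)):
    previous_states = [states_buffer[(index - n - 1) % length] for n in range(length)]
with tf.control_dependencies(control_inputs=previous_states):
    # Advance the write position only after the reads above have happened.
    step = tf.assign(ref=index, value=((tf.maximum(x=index, y=0) + 1) % length))
with tf.control_dependencies(control_inputs=(step,)):
    sequence = tf.expand_dims(input=tf.concat(values=previous_states, axis=-1), axis=0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for t in range(4):
        # Each output has shape (1, length * feature_dim).
        print(sess.run(sequence, feed_dict={state: np.full((1, feature_dim), float(t))}))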