def _create_state(self):
    """Allocate the non-trainable state variables mutated by the recurrence.

    Creates, in this order, ``self.stack``, ``self.queue``,
    ``self.buff_cursors``, ``self.cursors`` and ``self.tracking_value``.
    All are float32 ``tf.Variable`` objects with ``trainable=False``.
    Also collects them in ``self._aux_vars`` and stores an op that
    (re-)initializes exactly those variables in
    ``self.variable_initializer``.
    """
    def aux_variable(initial_value, var_name):
        # Recurrence state is bookkeeping, not a model parameter, so it is
        # excluded from training.
        return tf.Variable(initial_value, trainable=False, name=var_name)

    # The stack and the queue are both flattened tensors of
    # stack_size * batch_size rows: `stack_size` consecutive blocks of
    # `batch_size` rows each.
    stack_shape = (self.stack_size * self.batch_size, self.model_dim)
    self.stack = aux_variable(tf.zeros(stack_shape, dtype=tf.float32),
                              "stack")
    queue_shape = (self.stack_size * self.batch_size,)
    self.queue = aux_variable(tf.zeros(queue_shape, dtype=tf.float32),
                              "queue")

    # One cursor value per batch element.
    per_example = (self.batch_size,)
    self.buff_cursors = aux_variable(
        tf.zeros(per_example, dtype=tf.float32), "buff_cursors")
    # Initialized to -1 (presumably "nothing pushed yet" -- confirm against
    # the recurrence that reads these cursors).
    self.cursors = aux_variable(
        tf.ones(per_example, dtype=tf.float32) * -1, "cursors")

    # TODO make parameterizable
    tracking_shape = (self.batch_size, self.tracking_dim)
    self.tracking_value = aux_variable(
        tf.zeros(tracking_shape, dtype=tf.float32), "tracking_value")

    # A single op that resets exactly the auxiliary variables created above.
    self._aux_vars = [self.stack, self.queue, self.buff_cursors,
                      self.cursors, self.tracking_value]
    self.variable_initializer = tf.initialize_variables(self._aux_vars)
# (Scraped web-page residue, not code: "comment list" / "table of contents".)