def init_thin_stack(batch_size, max_num_concepts):
    """Builds the initial thin-stack state for a whole batch.

    Args:
      batch_size: Number of rows in the batch. Used with tf.range,
        tf.to_int64 and tf.pack, so presumably either a Python int or a
        scalar int32 tensor (confirm against callers).
      max_num_concepts: Width of each stack row. Presumably a Python int,
        since it is passed to tf.reshape to make dimension 1 statically
        known (confirm against callers).

    Returns:
      A pair (thin_stack, thin_stack_head_next):
        thin_stack: int32 tensor of shape [batch_size, max_num_concepts].
          Column 0 is 0 and every other entry is -1.
        thin_stack_head_next: int32 tensor of shape [batch_size], all ones.
          It holds the index of the element just past the stack head.
    """
    batch_shape = tf.pack([batch_size])

    # Every slot starts at -1, which points to the initial state.
    filled = -tf.ones(tf.pack([batch_size, max_num_concepts]), dtype=tf.int32)
    # Reshape only so that dimension 1 becomes statically known.
    stack = tf.reshape(filled, [-1, max_num_concepts])

    # Add +1 at (row, 0) for every row, turning column 0 from -1 into 0.
    rows = tf.range(batch_size)
    cols = tf.zeros(batch_shape, dtype=tf.int32)
    coords = tf.transpose(tf.to_int64(tf.pack([rows, cols])))
    dense_shape = tf.pack([tf.to_int64(batch_size), max_num_concepts])
    bump = tf.SparseTensor(coords,
                           tf.ones(batch_shape, dtype=tf.int32),
                           dense_shape)
    stack = stack + tf.sparse_tensor_to_dense(bump)

    # Slot 0 stands for the empty stack, so the position after the head
    # is always >= 1 and starts at exactly 1.
    head_next = tf.ones(batch_shape, dtype=tf.int32)
    return stack, head_next
评论列表
文章目录