def children_tensor(nodes, children, feature_size):
    """Build the children tensor from the input nodes and child lookup.

    For every node, gathers the feature vectors of its children.

    Args:
        nodes: rank-3 tensor indexed as (batch, node, feature). Its first two
            dimensions supply ``batch_size`` and ``num_nodes``.
        children: rank-3 integer tensor of node indices into axis 1 of
            ``nodes``, shaped (batch_size, num_nodes, max_children). Index 0
            looks up a zero vector instead of the root node's features, so 0
            can serve as the "no child" / padding value.
        feature_size: size of the last dimension of ``nodes``. It is used to
            build the zero vector that replaces the root.

    Returns:
        Tensor of shape (batch_size, num_nodes, max_children, feature_size)
        with the same dtype as ``nodes``.
    """
    with tf.name_scope('children_tensor'):
        max_children = tf.shape(children)[2]
        batch_size = tf.shape(nodes)[0]
        num_nodes = tf.shape(nodes)[1]

        # Replace the root node (index 0) with the zero vector so lookups of
        # index 0 return zeros instead of the root vector.
        # zero_vecs is (batch_size x 1 x feature_size). Match the dtype of
        # `nodes`; the default float32 would make the concat below fail for
        # float64/float16 inputs.
        zero_vecs = tf.zeros((batch_size, 1, feature_size), dtype=nodes.dtype)
        # vector_lookup is (batch_size x num_nodes x feature_size)
        vector_lookup = tf.concat([zero_vecs, nodes[:, 1:, :]], axis=1)

        # children is (batch_size x num_nodes x max_children x 1)
        children = tf.expand_dims(children, axis=3)
        # Prepend the batch index to the last dimension of children so that
        # gather_nd addresses vector_lookup by (batch, node) pairs.
        # tf.range yields int32, so cast it to the index dtype of `children`.
        # Without the cast, int64 child indices would fail to concat.
        # batch_indices is (batch_size x 1 x 1 x 1)
        batch_indices = tf.reshape(
            tf.cast(tf.range(0, batch_size), children.dtype),
            (batch_size, 1, 1, 1),
        )
        # batch_indices is (batch_size x num_nodes x max_children x 1)
        batch_indices = tf.tile(batch_indices, [1, num_nodes, max_children, 1])
        # children is (batch_size x num_nodes x max_children x 2)
        children = tf.concat([batch_indices, children], axis=3)

        # output has shape (batch_size x num_nodes x max_children x feature_size)
        # NOTE: tf < 1.1 contains a bug that makes backprop not work for this!
        return tf.gather_nd(vector_lookup, children, name='children')
评论列表
文章目录