def eta_r(children, t_coef):
    """Compute the weight matrix for how much each child belongs to the 'right'.

    For a node with ``n`` occupied child slots (``n >= 2``), the child in
    (0-based) slot ``j`` gets weight ``(1 - t_coef) * j / (n - 1)``. A node
    with exactly one child gets a fixed weight of 0.5 on the first child
    slot. A node with no children gets all-zero weights. The last axis has
    one extra leading slot (presumably the node itself -- TODO confirm),
    which always receives weight 0.

    Assumes occupied child slots are left-aligned, with zero padding at the
    end (NOTE(review): confirm against the tree batching code). Otherwise
    ``j`` is the slot position, not the child's rank.

    Args:
        children: tensor of shape (batch_size, max_tree_size, max_children)
            holding child indices, where 0 marks an empty child slot.
        t_coef: coefficient tensor combined element-wise with the result.
            Presumably broadcastable to
            (batch_size, max_tree_size, max_children + 1) -- TODO confirm
            against callers.

    Returns:
        float32 tensor of shape (batch_size, max_tree_size, max_children + 1).
    """
    with tf.name_scope('coef_r'):
        children = tf.cast(children, tf.float32)
        batch_size = tf.shape(children)[0]
        max_tree_size = tf.shape(children)[1]
        max_children = tf.shape(children)[2]

        # Number of occupied child slots per node, shape
        # (batch_size, max_tree_size, 1), then tiled to
        # (batch_size, max_tree_size, max_children + 1).
        # NOTE(review): `keep_dims` is the older TF 1.x spelling (later
        # deprecated in favour of `keepdims`); kept to match the TF version
        # this file targets.
        num_siblings = tf.cast(
            tf.count_nonzero(children, axis=2, keep_dims=True),
            dtype=tf.float32
        )
        num_siblings = tf.tile(
            num_siblings, [1, 1, max_children + 1], name='num_siblings'
        )

        # 1 where a child slot is occupied, 0 otherwise. The extra leading
        # column is always 0.
        # Shape: (batch_size, max_tree_size, max_children + 1).
        mask = tf.concat(
            [tf.zeros((batch_size, max_tree_size, 1)),
             tf.minimum(children, tf.ones(tf.shape(children)))],
            axis=2, name='mask'
        )

        # Slot positions [-1, 0, 1, ..., max_children - 1] for every node,
        # zeroed wherever no child is present (including the leading slot).
        # Shape: (batch_size, max_tree_size, max_children + 1).
        child_indices = tf.multiply(tf.tile(
            tf.expand_dims(
                tf.expand_dims(
                    tf.range(-1.0, tf.cast(max_children, tf.float32), 1.0,
                             dtype=tf.float32),
                    axis=0
                ),
                axis=0
            ),
            [batch_size, max_tree_size, 1]
        ), mask, name='child_indices')

        # Weights used when a node has exactly one child: 0.5 on the first
        # child slot (index 1 of the last axis), 0 elsewhere.
        # This is built with one_hot instead of concatenating a
        # (max_children - 1)-wide block of zeros, which has a negative size
        # and fails when max_children == 0. one_hot returns an all-off row
        # when the index is out of range, so the depth-1 case stays valid.
        singles_row = tf.one_hot(
            1, max_children + 1, on_value=0.5, off_value=0.0, dtype=tf.float32
        )
        # Broadcast the row to (batch_size, max_tree_size, max_children + 1).
        singles = tf.add(tf.zeros_like(num_siblings), singles_row,
                         name='singles')

        is_single = tf.equal(num_siblings, 1.0)
        # tf.where evaluates both branches. Dividing by (num_siblings - 1)
        # directly would produce 0/0 = NaN in the rows later replaced by
        # `singles`, and those NaNs still propagate into gradients.
        # Substitute a harmless denominator of 1 there; every other row
        # (including leaf rows, where the denominator is -1) is unchanged.
        safe_denominator = tf.where(
            is_single, tf.ones_like(num_siblings), num_siblings - 1.0
        )
        spread = tf.multiply(
            (1.0 - t_coef), tf.divide(child_indices, safe_denominator)
        )

        # eta_r has shape (batch_size, max_tree_size, max_children + 1).
        return tf.where(is_single, singles, spread, name='coef_r')
评论列表
文章目录