def conv_step(nodes, children, feature_size, w_t, w_r, w_l, b_conv):
"""Convolve a batch of nodes and children.
Lots of high dimensional tensors in this function. Intuitively it makes
more sense if we did this work with while loops, but computationally this
is more efficient. Don't try to wrap your head around all the tensor dot
products, just follow the trail of dimensions.
"""
    with tf.name_scope('conv_step'):
        # nodes is shape (batch_size x max_tree_size x feature_size)
        # children is shape (batch_size x max_tree_size x max_children)
        with tf.name_scope('trees'):
            # children_vectors will have shape
            # (batch_size x max_tree_size x max_children x feature_size)
            children_vectors = children_tensor(nodes, children, feature_size)

            # add a 4th dimension to the nodes tensor
            nodes = tf.expand_dims(nodes, axis=2)

            # tree_tensor is shape
            # (batch_size x max_tree_size x max_children + 1 x feature_size)
            tree_tensor = tf.concat([nodes, children_vectors], axis=2, name='trees')
        with tf.name_scope('coefficients'):
            # eta_t, eta_r and eta_l compute the "top", "right" and "left"
            # position coefficients used to blend the three weight matrices.
            # coefficient tensors are shape
            # (batch_size x max_tree_size x max_children + 1)
            c_t = eta_t(children)
            c_r = eta_r(children, c_t)
            c_l = eta_l(children, c_t, c_r)

            # concatenate the position coefficients into a tensor
            # (batch_size x max_tree_size x max_children + 1 x 3)
            coef = tf.stack([c_t, c_r, c_l], axis=3, name='coef')
        with tf.name_scope('weights'):
            # stack weight matrices on top to make a weight tensor
            # (3 x feature_size x output_size)
            weights = tf.stack([w_t, w_r, w_l], axis=0)
        with tf.name_scope('combine'):
            batch_size = tf.shape(children)[0]
            max_tree_size = tf.shape(children)[1]
            max_children = tf.shape(children)[2]

            # flatten the batch and tree dimensions so we can use batched matmul
            x = batch_size * max_tree_size
            y = max_children + 1
            result = tf.reshape(tree_tensor, (x, y, feature_size))
            coef = tf.reshape(coef, (x, y, 3))

            # for every node, sum the node-plus-children feature vectors
            # weighted by each of the three position coefficients
            result = tf.matmul(result, coef, transpose_a=True)
            result = tf.reshape(result, (batch_size, max_tree_size, 3, feature_size))

            # contract the coefficient and feature dimensions against the
            # stacked weight tensor
            # output is (batch_size x max_tree_size x output_size)
            result = tf.tensordot(result, weights, [[2, 3], [0, 1]])

            # add the bias and apply the nonlinearity; the shape is unchanged
            return tf.nn.tanh(result + b_conv, name='conv')
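
To see the trail of dimensions concretely, here is a minimal sketch that runs just the "combine" stage on random tensors. The sizes (a batch of 2 trees, up to 5 nodes, up to 3 children, 4-dimensional features, 6 output units) are made up for illustration, and it assumes TensorFlow 2.x eager execution rather than the TF 1.x graph style used above; the tree tensor and coefficient tensor are random stand-ins for what children_tensor and eta_t/eta_r/eta_l would produce.

import tensorflow as tf

# made-up sizes, purely for illustration
batch_size, max_tree_size, max_children = 2, 5, 3
feature_size, output_size = 4, 6

# random stand-ins for tree_tensor, coef and the stacked weights
tree_tensor = tf.random.normal((batch_size, max_tree_size, max_children + 1, feature_size))
coef = tf.random.normal((batch_size, max_tree_size, max_children + 1, 3))
weights = tf.random.normal((3, feature_size, output_size))
b_conv = tf.zeros((output_size,))

# flatten the batch and tree dimensions, exactly as in conv_step
x = batch_size * max_tree_size
y = max_children + 1
result = tf.reshape(tree_tensor, (x, y, feature_size))
coef = tf.reshape(coef, (x, y, 3))

# weighted sums over each node and its children, then contract with the weights
result = tf.matmul(result, coef, transpose_a=True)
result = tf.reshape(result, (batch_size, max_tree_size, 3, feature_size))
result = tf.tensordot(result, weights, [[2, 3], [0, 1]])
out = tf.nn.tanh(result + b_conv)

print(out.shape)  # (2, 5, 6) -> (batch_size, max_tree_size, output_size)

The only thing the real layer adds on top of this is that tree_tensor and coef are built from the actual node vectors and child indices instead of random data.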