def compute_tree(self, emb_x, tree):
    """Build the symbolic hidden states for every node of a tree.

    The first ``num_leaves`` rows of ``emb_x`` are passed through the leaf
    unit; the remaining rows are consumed one per step, together with the
    matching row of ``tree``, by a ``theano.scan`` recurrence that applies
    the recursive unit.

    Side effect: (re)creates ``self.recursive_unit`` and ``self.leaf_unit``
    on every call.

    Args:
        emb_x: Node embeddings. Rows ``[:num_leaves]`` are treated as leaves
            and rows ``[num_leaves:]`` as internal nodes.
        tree: One row per internal node (its first dimension is used as the
            internal-node count). Entries greater than -1 are used as child
            positions; other entries mark a missing child.
            NOTE(review): the exact indexing convention is defined by the
            caller that builds ``tree`` -- confirm there.

    Returns:
        The leaf hidden states concatenated (axis 0) with the hidden state
        produced at each scan step.
    """
    self.recursive_unit = self.create_recursive_unit()
    self.leaf_unit = self.create_leaf_unit()

    internal_count = tree.shape[0]  # number of internal nodes
    leaf_count = self.num_words - internal_count

    # Leaf hidden/cell states come straight from the leaf unit.
    leaf_states, _ = theano.map(
        fn=self.leaf_unit,
        sequences=[emb_x[:leaf_count]])
    leaf_h, leaf_c = leaf_states

    # For irregular trees the state buffer starts with the leaves repeated
    # twice; the extra ``leaf_count`` offset in the step function below
    # accounts for that.
    if self.irregular_tree:
        start_h = T.concatenate([leaf_h, leaf_h], axis=0)
        start_c = T.concatenate([leaf_c, leaf_c], axis=0)
    else:
        start_h, start_c = leaf_h, leaf_c

    def _step(node_emb, children, step, buf_h, buf_c, _prev_parent):
        # ``_prev_parent`` is never read; scan passes it because the parent
        # hidden state is declared as a recurrent output.
        has_child = children > -1
        # The buffers lose their first row on every step (see the return
        # below), so positions of existing children are shifted by ``step``.
        shift = leaf_count * int(self.irregular_tree) - has_child * step
        child_idx = children + shift
        # Rows gathered for missing children are zeroed by the mask.
        child_mask = has_child.dimshuffle(0, 'x')
        kids_h = buf_h[child_idx] * child_mask
        kids_c = buf_c[child_idx] * child_mask

        new_h, new_c = self.recursive_unit(node_emb, kids_h, kids_c, has_child)

        row_shape = [1, self.hidden_dim]
        grown_h = T.concatenate([buf_h, new_h.reshape(row_shape)])
        grown_c = T.concatenate([buf_c, new_c.reshape(row_shape)])
        # Append the new parent, drop the oldest row: buffer length is fixed.
        return grown_h[1:], grown_c[1:], new_h

    # Initial value for the third recurrent output; its contents are unused.
    dummy = theano.shared(self.init_vector([self.hidden_dim]))
    scan_outputs, _ = theano.scan(
        fn=_step,
        outputs_info=[start_h, start_c, dummy],
        sequences=[emb_x[leaf_count:], tree, T.arange(internal_count)],
        n_steps=internal_count)
    _, _, parent_h = scan_outputs

    return T.concatenate([leaf_h, parent_h], axis=0)
评论列表
文章目录