def _compute_vert_context_soft(self, edge_factor, vert_factor, reuse=False):
    """Attention-style (sigmoid-gated) pooling of edge messages into vertices.

    Each row of the pairing tables describes one edge pair:

    * ``self.edge_pair_mask_inds[:, 0]`` selects its outgoing edge feature,
    * ``self.edge_pair_mask_inds[:, 1]`` selects its incoming edge feature,
    * ``self.edge_pair_segment_inds`` selects the vertex it belongs to.

    For each direction, the selected edge feature is concatenated with the
    owning vertex's feature. The concatenation goes through a one-unit,
    non-ReLU ``fc`` layer and a sigmoid to produce a gate. The edge feature
    is scaled by that gate, and the two directions are added. The rows are
    then summed per vertex with ``tf.segment_sum`` (whose TF documentation
    expects the segment ids to be sorted).

    Args:
        edge_factor: edge feature tensor, gathered via ``utils.pad_and_gather``.
            NOTE(review): presumably that helper pads ``edge_factor`` so that
            masked-out indices gather a neutral row. Confirm in ``utils``.
        vert_factor: vertex feature tensor, gathered by
            ``self.edge_pair_segment_inds``.
        reuse: forwarded as ``reuse`` to both ``fc`` layers, presumably so
            their weights are shared across repeated calls.

    Returns:
        The per-vertex context tensor produced by ``tf.segment_sum``.
    """
    pair_inds = self.edge_pair_mask_inds
    owner_inds = self.edge_pair_segment_inds

    # Edge features on the outgoing (column 0) and incoming (column 1) side
    # of every pair.
    out_edge = utils.pad_and_gather(edge_factor, pair_inds[:, 0])
    in_edge = utils.pad_and_gather(edge_factor, pair_inds[:, 1])

    # Feature of the vertex that owns each pair, row-aligned with the edges.
    owner_vert = tf.gather(vert_factor, owner_inds)

    # Gate inputs: each edge feature placed next to its owning vertex's
    # feature along axis 1 (assumes 2-D [rows, features] tensors).
    out_gate_in = tf.concat(concat_dim=1, values=[out_edge, owner_vert])
    in_gate_in = tf.concat(concat_dim=1, values=[in_edge, owner_vert])

    # One scalar gate per row and direction: linear fc(1) -> sigmoid.
    # The layer names are kept literal because they identify the
    # layers' weights.
    gate_specs = (
        (out_gate_in, 'out_edge_w_fc', 'out_edge_score'),
        (in_gate_in, 'in_edge_w_fc', 'in_edge_score'),
    )
    for gate_in, fc_name, score_name in gate_specs:
        self.feed(gate_in).fc(
            1, relu=False, reuse=reuse, name=fc_name).sigmoid(name=score_name)
    out_gate, in_gate = [self.get_output(spec[2]) for spec in gate_specs]

    # Scale each edge feature by its gate. The fc has a single output unit,
    # so the gate presumably broadcasts across feature columns. Then merge
    # the two directions and pool the rows into their owning vertex.
    gated_sum = tf.mul(out_edge, out_gate) + tf.mul(in_edge, in_gate)
    return tf.segment_sum(gated_sum, owner_inds)
评论列表
文章目录