def combine_messages(self, forward_messages, backward_messages, self_loop_messages, previous_code, mode='train'):
    """Aggregate edge messages into vertex embeddings.

    The forward/backward messages are summed into vertices by left-multiplying
    them with the graph's sparse forward/backward incidence matrices. The two
    results are added together and to ``self_loop_messages``, and a ReLU is
    applied when ``self.use_nonlinearity`` is true.

    In ``'train'`` mode the aggregation is incremental. Only the difference
    between the incoming messages and the cached messages at rows ``self.I``
    is aggregated. That result is added to ``self.cached_vertex_embeddings``.
    As a side effect, three update ops are stored on ``self``:

    * ``self.f_upd`` / ``self.b_upd`` scatter the new messages into
      ``self.cached_messages_f`` / ``self.cached_messages_b`` at ``self.I``;
    * ``self.v_upd`` assigns the updated embeddings to
      ``self.cached_vertex_embeddings``.

    This method only builds those ops; it does not run them.
    NOTE(review): presumably the caller fetches them together with the
    training step — confirm.

    Args:
        forward_messages: Dense messages for forward edges. In train mode its
            rows line up with the cache rows selected by ``self.I``
            (presumably one row per edge in the current batch — TODO confirm).
        backward_messages: Same as ``forward_messages``, for backward edges.
        self_loop_messages: Per-vertex term added before the optional
            nonlinearity.
        previous_code: Not used by this method; kept for interface
            compatibility.
        mode: ``'train'`` selects the cached/incremental path. Any other value
            aggregates the full messages and leaves the caches untouched.

    Returns:
        The vertex embeddings tensor, ReLU-activated if
        ``self.use_nonlinearity`` is set.
    """
    graph = self.get_graph()
    mtr_f = graph.forward_incidence_matrix(normalization=('none', 'recalculated'))
    mtr_b = graph.backward_incidence_matrix(normalization=('none', 'recalculated'))

    if mode == 'train':
        # Deltas against the cached messages: only the change is propagated.
        forward_messages_comp = forward_messages - tf.nn.embedding_lookup(self.cached_messages_f, self.I)
        backward_messages_comp = backward_messages - tf.nn.embedding_lookup(self.cached_messages_b, self.I)

        # The cache writes must happen only after the cache reads above. Depend
        # on the *_comp tensors (which consume the lookups), not merely on the
        # incoming messages; those are already data inputs of scatter_update,
        # so depending on them orders nothing. Without this, if the updates
        # run in the same session.run, TF may scatter first. The lookup would
        # then see the new values and the deltas would silently become zero.
        with tf.control_dependencies([forward_messages_comp, backward_messages_comp]):
            self.f_upd = tf.scatter_update(self.cached_messages_f, self.I, forward_messages)
            self.b_upd = tf.scatter_update(self.cached_messages_b, self.I, backward_messages)

        collected_messages_f = tf.sparse_tensor_dense_matmul(mtr_f, forward_messages_comp)
        collected_messages_b = tf.sparse_tensor_dense_matmul(mtr_b, backward_messages_comp)
        new_embedding = collected_messages_f + collected_messages_b

        # Incremental update: previous embeddings plus the aggregated deltas.
        updated_vertex_embeddings = new_embedding + self.cached_vertex_embeddings
        with tf.control_dependencies([updated_vertex_embeddings]):
            self.v_upd = tf.assign(self.cached_vertex_embeddings, updated_vertex_embeddings)
    else:
        # Non-train path: aggregate the full messages; caches are not touched.
        collected_messages_f = tf.sparse_tensor_dense_matmul(mtr_f, forward_messages)
        collected_messages_b = tf.sparse_tensor_dense_matmul(mtr_b, backward_messages)
        updated_vertex_embeddings = collected_messages_f + collected_messages_b

    pre_activation = updated_vertex_embeddings + self_loop_messages
    if self.use_nonlinearity:
        return tf.nn.relu(pre_activation)
    return pre_activation
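# 评论列表 ("Comment list") / 文章目录 ("Table of contents"): web-page navigation residue from the scraped source, not code.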