gcn_basis_stored.py source code

python

Project: RelationPrediction  Author: MichSchli
# Method of the layer class; assumes TensorFlow 1.x graph mode and `import tensorflow as tf` at module level.
# `previous_code` is not used in this method.
def combine_messages(self, forward_messages, backward_messages, self_loop_messages, previous_code, mode='train'):
    # Sparse incidence matrices that route edge messages to their receiving vertices
    mtr_f = self.get_graph().forward_incidence_matrix(normalization=('none', 'recalculated'))
    mtr_b = self.get_graph().backward_incidence_matrix(normalization=('none', 'recalculated'))

    if mode == 'train':
        # self.I indexes the cached message rows for the current batch; keep only the change
        # between the freshly computed messages and the cached ones
        forward_messages_comp = forward_messages - tf.nn.embedding_lookup(self.cached_messages_f, self.I)
        backward_messages_comp = backward_messages - tf.nn.embedding_lookup(self.cached_messages_b, self.I)

        # Depend on the differences (not just the new messages) so the cached values
        # are read before scatter_update overwrites them
        with tf.control_dependencies([forward_messages_comp, backward_messages_comp]):
            self.f_upd = tf.scatter_update(self.cached_messages_f, self.I, forward_messages)
            self.b_upd = tf.scatter_update(self.cached_messages_b, self.I, backward_messages)

        collected_messages_f = tf.sparse_tensor_dense_matmul(mtr_f, forward_messages_comp)
        collected_messages_b = tf.sparse_tensor_dense_matmul(mtr_b, backward_messages_comp)

        # Aggregation is linear, so adding the aggregated differences to the cached
        # vertex embeddings yields the updated aggregate
        new_embedding = collected_messages_f + collected_messages_b
        updated_vertex_embeddings = new_embedding + self.cached_vertex_embeddings

        with tf.control_dependencies([updated_vertex_embeddings]):
            self.v_upd = tf.assign(self.cached_vertex_embeddings, updated_vertex_embeddings)
    else:
        # Outside training, aggregate the messages directly without touching the caches
        collected_messages_f = tf.sparse_tensor_dense_matmul(mtr_f, forward_messages)
        collected_messages_b = tf.sparse_tensor_dense_matmul(mtr_b, backward_messages)

        updated_vertex_embeddings = collected_messages_f + collected_messages_b

    # Add the self-loop messages and optionally apply ReLU
    if self.use_nonlinearity:
        activated = tf.nn.relu(updated_vertex_embeddings + self_loop_messages)
    else:
        activated = updated_vertex_embeddings + self_loop_messages

    return activated
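The training branch depends on aggregation being linear. Summing the cached vertex embeddings with the aggregated message differences gives the same result as re-aggregating every message from a refreshed cache. The NumPy sketch below checks that identity on a toy graph. It illustrates the idea under stated assumptions and is not the project's implementation. The dense incidence matrix, the function name `incremental_aggregate`, and slicing the incidence matrix to the batch columns are all hypothetical. Only one message direction is shown; the method above applies the same step to forward and backward messages and sums the results.

```python
import numpy as np


def incremental_aggregate(incidence, cached_messages, cached_embeddings, batch_idx, batch_messages):
    """Hypothetical helper: refresh only the edges in batch_idx.

    Because aggregation is linear,
    M @ new_all == M @ old_all + M[:, I] @ (new_I - old_I).
    Assumes batch_idx contains no duplicate edge indices.
    """
    # Difference between freshly computed and cached messages for the batch edges
    delta = batch_messages - cached_messages[batch_idx]
    # Only the incidence columns of the batch edges contribute to the change
    updated_embeddings = cached_embeddings + incidence[:, batch_idx] @ delta
    # Write the fresh messages into a copy of the cache (plays the role of scatter_update)
    new_cache = cached_messages.copy()
    new_cache[batch_idx] = batch_messages
    return new_cache, updated_embeddings


rng = np.random.default_rng(0)
n_vertices, n_edges, dim = 4, 6, 3

# Toy dense incidence matrix: edge e delivers its message to vertex targets[e]
targets = np.array([0, 1, 1, 2, 3, 3])
incidence = np.zeros((n_vertices, n_edges))
incidence[targets, np.arange(n_edges)] = 1.0

# Start from a consistent cache: embeddings equal the full aggregation
cached_messages = rng.normal(size=(n_edges, dim))
cached_embeddings = incidence @ cached_messages

# Recompute messages for edges 1 and 4 only
batch_idx = np.array([1, 4])
batch_messages = rng.normal(size=(len(batch_idx), dim))
new_cache, updated = incremental_aggregate(
    incidence, cached_messages, cached_embeddings, batch_idx, batch_messages
)

# The incremental result matches a full recomputation from the refreshed cache
print(np.allclose(updated, incidence @ new_cache))  # True
```

With duplicate entries in `batch_idx`, each copy's difference would be added against the same old value while the cache keeps only the last write, so the two sides would no longer agree. The sketch therefore assumes each batch refreshes every edge at most once.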