community_embeddings.py source code

python

Project: nodeembedding-to-communityembedding  Author: andompesta
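import numpy as np
from gensim.utils import chunkize_serial
from scipy.stats import multivariate_normal


def loss(self, nodes, model, beta, chunksize=150):
    """
    Compute the o3 (community embedding) loss for the given nodes.
    :param nodes: nodes present in the batch (keys of model.vocab)
    :param model: model containing all the shared data
    :param beta: trade-off parameter
    :param chunksize: number of nodes processed per chunk
    """
    ret_loss = 0.0
    node_ids = (model.vocab[x].index for x in nodes)
    for node_index in chunkize_serial(node_ids, chunksize):
        # Embeddings of the nodes in this chunk, shape (len(node_index), d)
        embeddings = model.node_embedding[node_index]

        batch_loss = np.zeros(len(node_index), dtype=np.float32)
        for com in range(model.k):
            rd = multivariate_normal(model.centroid[com], model.covariance_mat[com])
            # Membership-weighted log-density: pi[i, com] * log N(x_i | centroid, covariance)
            # TODO: check whether this loop can be done as a single matrix operation
            batch_loss += rd.logpdf(embeddings).astype(np.float32) * model.pi[node_index, com]

        # Accumulate the loss over all chunks
        ret_loss += abs(batch_loss.sum())

    return ret_loss * (beta / model.k)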
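The method above depends on a trained `model` object and on gensim/scipy. The sketch below is a hypothetical, self-contained reproduction of the same computation on toy data, assuming the embeddings are rows of a NumPy array, `pi` is an (n, k) community-membership matrix, and each community has its own centroid and covariance. `gaussian_logpdf` is a NumPy-only stand-in for `scipy.stats.multivariate_normal(...).logpdf`, and the function names are illustrative, not part of the original project. It also shows one way to address the TODO: build the n×k matrix of log-densities once and take the membership-weighted sum with a single elementwise product (a per-community loop remains because each community has its own covariance).

```python
import numpy as np


def gaussian_logpdf(x, mean, cov):
    """Row-wise log-density of N(mean, cov) for x of shape (n, d).

    NumPy-only stand-in for scipy.stats.multivariate_normal(mean, cov).logpdf(x);
    cov is assumed symmetric positive definite.
    """
    d = mean.shape[0]
    diff = x - mean                               # (n, d)
    _, logdet = np.linalg.slogdet(cov)            # log|cov|
    sol = np.linalg.solve(cov, diff.T).T          # rows are cov^{-1} diff_i
    maha = np.einsum("nd,nd->n", diff, sol)       # squared Mahalanobis distances
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + maha)


def o3_loss_loop(emb, pi, centroids, covs, beta, chunksize=150):
    """Hypothetical chunked version mirroring the method: sum_k pi[i,k] * log N(emb_i)."""
    n, k = pi.shape
    total = 0.0
    for start in range(0, n, chunksize):
        idx = np.arange(start, min(start + chunksize, n))
        batch = np.zeros(len(idx))
        for com in range(k):
            batch += gaussian_logpdf(emb[idx], centroids[com], covs[com]) * pi[idx, com]
        total += abs(batch.sum())                 # accumulate over chunks
    return total * (beta / k)


def o3_loss_vectorized(emb, pi, centroids, covs, beta):
    """Hypothetical single-pass version: (n, k) log-density matrix times pi."""
    k = pi.shape[1]
    logp = np.stack(
        [gaussian_logpdf(emb, centroids[c], covs[c]) for c in range(k)], axis=1
    )
    return abs((logp * pi).sum()) * (beta / k)


# Toy data (illustrative only)
rng = np.random.default_rng(0)
n, d, k = 10, 3, 2
emb = rng.normal(size=(n, d))
pi = rng.dirichlet(np.ones(k), size=n)            # each row sums to 1
centroids = rng.normal(size=(k, d))
covs = np.stack([np.eye(d) * (c + 1.0) for c in range(k)])

print(o3_loss_loop(emb, pi, centroids, covs, beta=0.1, chunksize=3))
print(o3_loss_vectorized(emb, pi, centroids, covs, beta=0.1))
```

Note that `abs` is applied per chunk, so the chunked and single-pass results coincide only when every chunk sum has the same sign. With the toy covariances above every log-density is negative, so both calls print the same value.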