def loss(self, nodes, model, beta, chunksize=150):
    """
    Forward function used to compute the o3 (community) loss.

    For every node ``i`` in ``nodes`` and every community ``k`` this sums
    ``model.pi[i, k] * log N(node_embedding[i] | centroid[k], covariance_mat[k])``,
    takes the absolute value of that total and scales it by ``beta / model.k``.

    :param nodes: iterable of node ids; each is resolved to a row index via
        ``model.vocab(node).index``
    :param model: model containing all the shared data (``vocab``,
        ``node_embedding``, ``centroid``, ``covariance_mat``, ``pi``, ``k``)
    :param beta: trade-off parameter weighting this loss term
    :param chunksize: number of nodes scored per vectorised batch; it only
        affects memory use and speed, not the returned value
    :return: the scaled loss as a float (``0.0`` when ``nodes`` is empty)
    :raises ValueError: if ``chunksize`` is smaller than 1
    """
    chunksize = int(chunksize)
    if chunksize < 1:
        raise ValueError(f"chunksize must be >= 1, got {chunksize!r}")

    node_indices = [model.vocab(node).index for node in nodes]
    if not node_indices:
        return 0.0

    # The K Gaussians do not depend on the batch, so build them once
    # rather than once per chunk.
    distributions = [
        multivariate_normal(model.centroid[com], model.covariance_mat[com])
        for com in range(model.k)
    ]

    total = 0.0
    for start in range(0, len(node_indices), chunksize):
        batch = node_indices[start:start + chunksize]
        embeddings = model.node_embedding[batch]
        batch_loss = np.zeros(len(batch), dtype=np.float32)
        # TODO: the per-community loop could likely be vectorised.
        for com, dist in enumerate(distributions):
            batch_loss += dist.logpdf(embeddings).astype(np.float32) * model.pi[batch, com]
        # Accumulate across chunks so every node contributes, not just the
        # nodes of the last chunk.
        total += float(batch_loss.sum())

    # abs() of the grand total (not per chunk) keeps the result independent
    # of how the nodes were chunked.
    return abs(total) * (beta / model.k)
community_embeddings.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录