def argmax_fun(W, indices, argmax_type='levi'):
    """Return the column index of W that best completes a : a* :: b : b*.

    Two objectives are supported:

    cosine: b* = argmax cosine(b*, b - a + a*)
    levi:   b* = argmax cos(b*, a*) * cos(b*, b) / (cos(b*, a) + eps)

    Args:
        W: 2-D array whose columns are the word vectors (it is indexed as
            ``W[:, indices]`` and normalised along axis 0).
        indices: column indices of the three query words. ``indices[0]`` is
            ``a`` (the term that is subtracted / divided out); the other two
            are ``a*`` and ``b``, whose order does not affect the result.
        argmax_type: ``'levi'`` (default) or ``'cosine'``.

    Returns:
        The index of the column of ``W`` that maximises the chosen objective.
        The query columns themselves are not excluded from the search.

    Raises:
        ValueError: if ``argmax_type`` is not ``'levi'`` or ``'cosine'``.
    """
    if argmax_type == 'levi':
        W = W / np.linalg.norm(W, axis=0)
        words3 = W[:, indices]
        # Map cosines from [-1, 1] to [0, 1] so every factor is non-negative.
        cosines = (words3.T.dot(W) + 1) / 2
        # 1e-3 is the eps of the formula above; it guards the division.
        obj = (cosines[1] * cosines[2]) / (cosines[0] + 1e-3)
        pred_idx = np.argmax(obj)
    elif argmax_type == 'cosine':
        # (a + a* + b) - 2a == a* + b - a; built from the unnormalised vectors.
        words3_vec = W[:, indices].sum(axis=1) - 2 * W[:, indices[0]]
        W = W / np.linalg.norm(W, axis=0)
        words3_vec = words3_vec / np.linalg.norm(words3_vec)
        cosines = words3_vec.T.dot(W)
        pred_idx = np.argmax(cosines)
    else:
        # Previously fell through and failed with UnboundLocalError on return.
        raise ValueError(
            "argmax_type must be 'levi' or 'cosine', got %r" % (argmax_type,)
        )
    return pred_idx
# Web-page navigation residue ("Comment list", "Table of contents"); not part of the code.