def cosine_sim2d(k, M, eps=1e-5):
    """Compute cosine similarity between each key and every row of its memory.

    Built from symbolic ``T`` operations (presumably ``theano.tensor`` --
    confirm against the file's imports), so ``k`` and ``M`` are expected to
    be symbolic tensors.

    Args:
        k: keys, shape (nb_samples, memory_width).
        M: memories, shape (nb_samples, memory_dim, memory_width).
        eps: small constant added to each L2 norm before dividing, so an
            all-zero key or memory row does not divide by zero. Defaults to
            1e-5, the value this function previously hard-coded.

    Returns:
        Tensor of shape (nb_samples, memory_dim) whose entry [b, i] is
        dot(k[b], M[b, i]) / ((||k[b]|| + eps) * (||M[b, i]|| + eps)).
    """
    # L2 norms along the memory_width axis.
    k_norm = T.sqrt(T.sum(T.sqr(k), 1)) + eps  # (nb_samples,)
    M_norm = T.sqrt(T.sum(T.sqr(M), 2)) + eps  # (nb_samples, memory_dim)
    # NOTE(review): eps is added outside the sqrt. The forward pass is safe
    # for all-zero rows, but d/dx sqrt(x) is unbounded at x == 0, so
    # gradients through an all-zero key or memory row may come out NaN.
    # Moving eps inside the sqrt would avoid that at the cost of slightly
    # different outputs -- confirm before changing.

    # Insert a broadcastable axis so each sample's key is compared against
    # every memory row of that same sample.
    k_b = k[:, None, :]  # (nb_samples, 1, memory_width)
    k_norm_b = k_norm[:, None]  # (nb_samples, 1)
    dots = T.sum(k_b * M, axis=2)  # (nb_samples, memory_dim)
    return dots / (k_norm_b * M_norm)  # (nb_samples, memory_dim)
# [Scraped web-page residue, not part of the code: "Comment list", "Table of contents"]