def batch_full_cosine_similarity(tensor1, tensor2, eps=1e-8):
    """Compute all pairwise cosine similarities between two batched sequences.

    For every batch index ``k`` the result ``A[k, i, j]`` is the cosine
    similarity between the vectors ``tensor1[k, i, :]`` and
    ``tensor2[k, j, :]``.

    Args:
        tensor1: Tensor of shape ``(batch_size, seq_len_p, hidden)``.
        tensor2: Tensor of shape ``(batch_size, seq_len_q, hidden)``.
        eps: Lower bound applied to each vector norm before dividing, so
            that all-zero vectors (e.g. padding) yield a similarity of 0
            instead of NaN. Pass ``0`` to disable the clamping.

    Returns:
        Tensor of shape ``(batch_size, seq_len_p, seq_len_q)``.

    Raises:
        ValueError: If either input is not 3-dimensional, or if the batch
            or hidden sizes of the two inputs differ.
    """
    if tensor1.dim() != 3 or tensor2.dim() != 3:
        raise ValueError(
            "expected two 3-D tensors, got %d-D and %d-D"
            % (tensor1.dim(), tensor2.dim())
        )
    if tensor1.size(0) != tensor2.size(0):
        raise ValueError(
            "batch sizes differ: %d vs %d" % (tensor1.size(0), tensor2.size(0))
        )
    if tensor1.size(2) != tensor2.size(2):
        raise ValueError(
            "hidden sizes differ: %d vs %d" % (tensor1.size(2), tensor2.size(2))
        )

    # Batched matrix product gives every pairwise dot product directly,
    # without materialising (batch_size, seq_len_p, seq_len_q, hidden) copies:
    # (batch_size, seq_len_p, hidden) @ (batch_size, hidden, seq_len_q)
    # -> (batch_size, seq_len_p, seq_len_q)
    dotprod = torch.bmm(tensor1, tensor2.transpose(1, 2))

    # L2 norm of every vector, kept broadcastable against ``dotprod``.
    # -> (batch_size, seq_len_p, 1)
    norm1 = torch.norm(tensor1, p=2, dim=2, keepdim=True).clamp(min=eps)
    # -> (batch_size, 1, seq_len_q)
    norm2 = torch.norm(tensor2, p=2, dim=2, keepdim=True).clamp(min=eps)
    norm2 = norm2.transpose(1, 2)

    # Broadcasting the two norms yields the (seq_len_p, seq_len_q) grid of
    # norm products for each batch element.
    return dotprod / (norm1 * norm2)  # (batch_size, seq_len_p, seq_len_q)
# Comment list
# Table of contents