def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
    """
    Compute a per-head similarity between ``tensor_1`` and ``tensor_2``.

    Steps:

    1. ``tensor_1`` is multiplied (``torch.matmul``) by ``self._tensor_1_projection``.
       ``tensor_2`` is multiplied by ``self._tensor_2_projection``.
    2. The last dimension of each projection is reshaped with ``view`` from
       ``(..., projected_dim)`` to ``(..., num_heads, projected_dim // num_heads)``.
    3. Both split tensors are handed to ``self._internal_similarity``, and its
       result is returned unchanged.

    The split assumes each projected dimension is divisible by ``self.num_heads``.
    If it is not, ``view`` raises because the element counts disagree.
    NOTE(review): presumably this is validated when the module is constructed;
    that code is not visible here.

    Returns whatever ``self._internal_similarity`` produces. The expected shape is
    ``(..., num_heads)``, assuming the internal similarity reduces only the last
    dimension — confirm against the similarity function in use.
    """
    heads = self.num_heads

    def split_heads(projected: torch.Tensor) -> torch.Tensor:
        # (..., projected_dim) -> (..., num_heads, projected_dim // num_heads)
        per_head = projected.size(-1) // heads
        leading = list(projected.size())[:-1]
        return projected.view(*leading, heads, per_head)

    # Both projections are computed before either is split, matching the
    # original evaluation order (and therefore which error surfaces first).
    left_projected = torch.matmul(tensor_1, self._tensor_1_projection)
    right_projected = torch.matmul(tensor_2, self._tensor_2_projection)

    left_heads = split_heads(left_projected)
    right_heads = split_heads(right_projected)

    # The internal similarity works on the last (per-head projection) dimension,
    # so the extra head axis needs no special handling here.
    return self._internal_similarity(left_heads, right_heads)
评论列表
文章目录