def dot_nd(query, candidates):
    """Dot a query vector with every vector along the last axis of ``candidates``.

    Args:
        query (torch.Tensor): 1-D tensor of shape ``(query_dim,)``.
        candidates (torch.Tensor): Tensor of shape
            ``(d0, d1, ..., dn, query_dim)``. It may be non-contiguous
            (e.g. the result of ``transpose``/``permute``).

    Returns:
        torch.Tensor: Tensor of shape ``(d0, d1, ..., dn)``. Each entry is the
        dot product of ``query`` with the corresponding length-``query_dim``
        slice of ``candidates``. A 1-D ``candidates`` yields a 0-d tensor.

    Raises:
        RuntimeError: if ``query`` is not 1-D, if ``candidates`` is 0-d, or if
            the last dimension of ``candidates`` does not equal ``query_dim``.
    """
    if query.dim() != 1:
        # torch.matmul would silently accept a matrix query; keep the
        # vector-only contract that the previous torch.mv call enforced.
        raise RuntimeError(
            f"dot_nd expects a 1-D query, got shape {tuple(query.shape)}"
        )
    # With a 1-D second operand, matmul contracts it against the last axis of
    # the N-D first operand and returns shape candidates.shape[:-1].
    # Unlike the former view(-1, D) -> mv -> view(...) sequence, this works for
    # non-contiguous candidates, 1-D candidates, and query_dim == 0.
    return torch.matmul(candidates, query)
评论列表
文章目录