def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
projected_tensor_1 = torch.matmul(tensor_1, self._tensor_1_projection)
projected_tensor_2 = torch.matmul(tensor_2, self._tensor_2_projection)
# Here we split the last dimension of the tensors from (..., projected_dim) to
# (..., num_heads, projected_dim / num_heads), using tensor.view().
last_dim_size = projected_tensor_1.size(-1) // self.num_heads
new_shape = list(projected_tensor_1.size())[:-1] + [self.num_heads, last_dim_size]
split_tensor_1 = projected_tensor_1.view(*new_shape)
last_dim_size = projected_tensor_2.size(-1) // self.num_heads
new_shape = list(projected_tensor_2.size())[:-1] + [self.num_heads, last_dim_size]
split_tensor_2 = projected_tensor_2.view(*new_shape)
# And then we pass this off to our internal similarity function. Because the similarity
# functions don't care what dimension their input has, and only look at the last dimension,
# we don't need to do anything special here. It will just compute similarity on the
# projection dimension for each head, returning a tensor of shape (..., num_heads).
return self._internal_similarity(split_tensor_1, split_tensor_2)
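# A minimal usage sketch (not from the original class) of the head-splitting reshape above,
# assuming num_heads = 4 and a projected dimension of 8:
import torch

projected = torch.randn(2, 5, 8)             # (batch, seq_len, projected_dim)
num_heads = 4
head_dim = projected.size(-1) // num_heads   # 8 // 4 = 2
split = projected.view(*projected.size()[:-1], num_heads, head_dim)
print(split.shape)                           # torch.Size([2, 5, 4, 2])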
def backward(self, grad_output):
means, = self.saved_tensors
T = self.R.shape[0]
dim = means.dim()
# Add batch axis if necessary
if dim == 2:
T_, D = means.shape
B = 1
grad_output = grad_output.view(B, T, -1)
else:
B, T_, D = means.shape
grad = torch.matmul(self.R.transpose(0, 1), grad_output)
reshaped = not (T == T_)
if not reshaped:
grad = grad.view(B, self.num_windows, T, -1).transpose(
1, 2).contiguous().view(B, T, D)
if dim == 2:
return grad.view(-1, D)
return grad
def test_toeplitz_matmul_batch():
cols = torch.Tensor([
[1, 6, 4, 5],
[2, 3, 1, 0],
[1, 2, 3, 1],
])
rows = torch.Tensor([
[1, 2, 1, 1],
[2, 0, 0, 1],
[1, 5, 1, 0],
])
rhs_mats = torch.randn(3, 4, 2)
# Actual
lhs_mats = torch.zeros(3, 4, 4)
for i, (col, row) in enumerate(zip(cols, rows)):
lhs_mats[i].copy_(utils.toeplitz.toeplitz(col, row))
actual = torch.matmul(lhs_mats, rhs_mats)
# Fast toeplitz
res = utils.toeplitz.toeplitz_matmul(cols, rows, rhs_mats)
assert utils.approx_equal(res, actual)
def diag(self):
batch_size, n_data, n_interp = self.left_interp_indices.size()
# Batch compute the non-zero values of the outer products w_left^k w_right^k^T
left_interp_values = self.left_interp_values.unsqueeze(3)
right_interp_values = self.right_interp_values.unsqueeze(2)
interp_values = torch.matmul(left_interp_values, right_interp_values)
# Batch compute Toeplitz values that will be non-zero for row k
left_interp_indices = self.left_interp_indices.unsqueeze(3).expand(batch_size, n_data, n_interp, n_interp)
left_interp_indices = left_interp_indices.contiguous()
right_interp_indices = self.right_interp_indices.unsqueeze(2).expand(batch_size, n_data, n_interp, n_interp)
right_interp_indices = right_interp_indices.contiguous()
batch_interp_indices = Variable(left_interp_indices.data.new(batch_size))
torch.arange(0, batch_size, out=batch_interp_indices.data)
batch_interp_indices = batch_interp_indices.view(batch_size, 1, 1, 1)
batch_interp_indices = batch_interp_indices.expand(batch_size, n_data, n_interp, n_interp).contiguous()
base_var_vals = self.base_lazy_variable._batch_get_indices(batch_interp_indices.view(-1),
left_interp_indices.view(-1),
right_interp_indices.view(-1))
base_var_vals = base_var_vals.view(left_interp_indices.size())
diag = (interp_values * base_var_vals).sum(3).sum(2).sum(0)
return diag
def diag(self):
n_data, n_interp = self.left_interp_indices.size()
# Batch compute the non-zero values of the outer products w_left^k w_right^k^T
left_interp_values = self.left_interp_values.unsqueeze(2)
right_interp_values = self.right_interp_values.unsqueeze(1)
interp_values = torch.matmul(left_interp_values, right_interp_values)
# Batch compute Toeplitz values that will be non-zero for row k
left_interp_indices = self.left_interp_indices.unsqueeze(2).expand(n_data, n_interp, n_interp).contiguous()
right_interp_indices = self.right_interp_indices.unsqueeze(1).expand(n_data, n_interp, n_interp).contiguous()
base_var_vals = self.base_lazy_variable._get_indices(left_interp_indices.view(-1),
right_interp_indices.view(-1))
base_var_vals = base_var_vals.view(left_interp_indices.size())
diag = (interp_values * base_var_vals).sum(2).sum(1)
return diag
def test_matmul_out(self):
def check_matmul(size1, size2):
a = torch.randn(size1)
b = torch.randn(size2)
expected = torch.matmul(a, b)
out = torch.Tensor(expected.size()).zero_()
# make output non-contiguous
out = out.transpose(-1, -2).contiguous().transpose(-1, -2)
self.assertFalse(out.is_contiguous())
torch.matmul(a, b, out=out)
self.assertEqual(expected, out)
check_matmul((2, 3, 4), (2, 4, 5))
check_matmul((2, 3, 4), (4, 5))
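# A small illustration (added here, not part of the test above) of the broadcasting rule the
# test relies on: a (2, 3, 4) batch of matrices times a plain (4, 5) matrix broadcasts the
# second operand across the batch dimension.
import torch

print(torch.matmul(torch.randn(2, 3, 4), torch.randn(2, 4, 5)).shape)  # torch.Size([2, 3, 5])
print(torch.matmul(torch.randn(2, 3, 4), torch.randn(4, 5)).shape)     # torch.Size([2, 3, 5])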
def predict(self, x):
batch_size, dims = x.size()
query = F.normalize(self.query_proj(x), dim=1)
# Find the k-nearest neighbors of the query
scores = torch.matmul(query, torch.t(self.keys_var))
cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)
# softmax of cosine similarities - embedding
softmax_score = F.softmax(self.softmax_temperature * cosine_similarity, dim=1)
# retrieve memory values - prediction
y_hat_indices = topk_indices_var.data[:, 0]
y_hat = self.values[y_hat_indices]
return y_hat, softmax_score
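# A self-contained sketch (illustrative shapes, not the module's real attributes) of the
# cosine-similarity lookup in predict() above: normalized queries are scored against
# normalized memory keys, then top-k is taken over the score matrix.
import torch
import torch.nn.functional as F

queries = F.normalize(torch.randn(8, 16), dim=1)   # (batch, dim)
keys = F.normalize(torch.randn(100, 16), dim=1)    # (memory_size, dim)
scores = torch.matmul(queries, torch.t(keys))      # (batch, memory_size) cosine similarities
cosine_similarity, topk_indices = torch.topk(scores, k=5, dim=1)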
def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
intermediate = torch.matmul(tensor_1, self._weight_matrix)
result = (intermediate * tensor_2).sum(dim=-1)
return self._activation(result + self._bias)
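# A hedged sketch of the bilinear similarity computed above: it evaluates
# activation(x^T W y + b) for each pair of vectors. Shapes and names here are illustrative only.
import torch

x = torch.randn(2, 7, 16)   # (batch, seq_len, dim_1)
y = torch.randn(2, 7, 16)   # (batch, seq_len, dim_2)
W = torch.randn(16, 16)     # weight matrix
bias = torch.zeros(1)
similarity = (torch.matmul(x, W) * y).sum(dim=-1) + bias   # (batch, seq_len)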
def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
combined_tensors = util.combine_tensors(self._combination, [tensor_1, tensor_2])
dot_product = torch.matmul(combined_tensors, self._weight_vector)
return self._activation(dot_product + self._bias)
def forward(self, means):
# TODO: remove this
self.save_for_backward(means)
T = self.R.shape[0]
dim = means.dim()
# Add batch axis if necessary
if dim == 2:
T_, D = means.shape
B = 1
means = means.view(B, T_, D)
else:
B, T_, D = means.shape
# Check if means has proper shape
reshaped = not (T == T_)
if not reshaped:
static_dim = means.shape[-1] // self.num_windows
reshaped_means = means.contiguous().view(
B, T, self.num_windows, -1).transpose(
1, 2).contiguous().view(B, -1, static_dim)
else:
static_dim = means.shape[-1]
reshaped_means = means
out = torch.matmul(self.R, reshaped_means)
if dim == 2:
return out.view(-1, static_dim)
return out
def pairwise_ranking_loss(margin, x, v):
zero = torch.zeros(1)
diag_margin = margin * torch.eye(x.size(0))
if not args.no_cuda:
zero, diag_margin = zero.cuda(), diag_margin.cuda()
zero, diag_margin = Variable(zero), Variable(diag_margin)
x = x / torch.norm(x, 2, 1, keepdim=True)
v = v / torch.norm(v, 2, 1, keepdim=True)
prod = torch.matmul(x, v.transpose(0, 1))
diag = torch.diag(prod)
for_x = torch.max(zero, margin - torch.unsqueeze(diag, 1) + prod) - diag_margin
for_v = torch.max(zero, margin - torch.unsqueeze(diag, 0) + prod) - diag_margin
return (torch.sum(for_x) + torch.sum(for_v)) / x.size(0)
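# Example call for pairwise_ranking_loss above, with randomly generated embeddings. It assumes
# the script-level `args` object (with args.no_cuda set to True) that the function refers to.
import torch
from torch.autograd import Variable

x = Variable(torch.randn(5, 32))   # e.g. image embeddings
v = Variable(torch.randn(5, 32))   # e.g. sentence embeddings
loss = pairwise_ranking_loss(0.2, x, v)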
def matmul(self, other):
"""Matrix product of two tensors.
See :func:`torch.matmul`."""
return torch.matmul(self, other)
def __matmul__(self, other):
if not torch.is_tensor(other):
return NotImplemented
return self.matmul(other)
def matmul(self, other):
return torch.matmul(self, other)
def __matmul__(self, other):
if not isinstance(other, Variable):
return NotImplemented
return self.matmul(other)
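# With matmul/__matmul__ defined as above, the @ operator dispatches to torch.matmul,
# so these two expressions are equivalent:
import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
assert torch.equal(a @ b, torch.matmul(a, b))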
def test_toeplitz_matmul():
col = torch.Tensor([1, 6, 4, 5])
row = torch.Tensor([1, 2, 1, 1])
rhs_mat = torch.randn(4, 2)
# Actual
lhs_mat = utils.toeplitz.toeplitz(col, row)
actual = torch.matmul(lhs_mat, rhs_mat)
# Fast toeplitz
res = utils.toeplitz.toeplitz_matmul(col, row, rhs_mat)
assert utils.approx_equal(res, actual)
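# For reference (an illustrative expansion, assuming the usual Toeplitz convention that `col`
# gives the first column and `row` the first row, with a shared first entry): the matrix built
# from col = [1, 6, 4, 5] and row = [1, 2, 1, 1] is constant along each diagonal:
#   [[1, 2, 1, 1],
#    [6, 1, 2, 1],
#    [4, 6, 1, 2],
#    [5, 4, 6, 1]]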
def test_toeplitz_matmul_batchmat():
col = torch.Tensor([1, 6, 4, 5])
row = torch.Tensor([1, 2, 1, 1])
rhs_mat = torch.randn(3, 4, 2)
# Actual
lhs_mat = utils.toeplitz.toeplitz(col, row)
actual = torch.matmul(lhs_mat.unsqueeze(0), rhs_mat)
# Fast toeplitz
res = utils.toeplitz.toeplitz_matmul(col.unsqueeze(0), row.unsqueeze(0), rhs_mat)
assert utils.approx_equal(res, actual)
def test_left_interp_on_a_vector():
vector = torch.randn(6)
res = left_interp(interp_indices, interp_values, Variable(vector)).data
actual = torch.matmul(interp_matrix, vector)
assert approx_equal(res, actual)
def test_batch_left_interp_on_a_vector():
vector = torch.randn(6)
actual = torch.matmul(batch_interp_matrix, vector.unsqueeze(-1).unsqueeze(0)).squeeze(0)
res = left_interp(batch_interp_indices, batch_interp_values, Variable(vector)).data
assert approx_equal(res, actual)
def test_batch_left_interp_on_a_matrix():
batch_matrix = torch.randn(6, 3)
res = left_interp(batch_interp_indices, batch_interp_values, Variable(batch_matrix)).data
actual = torch.matmul(batch_interp_matrix, batch_matrix.unsqueeze(0))
assert approx_equal(res, actual)
def test_batch_left_interp_on_a_batch_matrix():
batch_matrix = torch.randn(2, 6, 3)
res = left_interp(batch_interp_indices, batch_interp_values, Variable(batch_matrix)).data
actual = torch.matmul(batch_interp_matrix, batch_matrix)
assert approx_equal(res, actual)
def test_forward_batch():
i = torch.LongTensor([[0, 0, 0, 1, 1, 1],
[0, 1, 1, 0, 1, 1],
[2, 0, 2, 2, 0, 2]])
v = torch.FloatTensor([3, 4, 5, 6, 7, 8])
sparse = torch.sparse.FloatTensor(i, v, torch.Size([2, 2, 3]))
dense = Variable(torch.randn(2, 3, 3))
res = gpytorch.dsmm(Variable(sparse), dense)
actual = torch.matmul(Variable(sparse.to_dense()), dense)
assert(torch.norm(res.data - actual.data) < 1e-5)
def test_backward_batch():
i = torch.LongTensor([[0, 0, 0, 1, 1, 1],
[0, 1, 1, 0, 1, 1],
[2, 0, 2, 2, 0, 2]])
v = torch.FloatTensor([3, 4, 5, 6, 7, 8])
sparse = torch.sparse.FloatTensor(i, v, torch.Size([2, 2, 3]))
dense = Variable(torch.randn(2, 3, 4), requires_grad=True)
dense_copy = Variable(dense.data.clone(), requires_grad=True)
grad_output = torch.randn(2, 2, 4)
res = gpytorch.dsmm(Variable(sparse), dense)
res.backward(grad_output)
actual = torch.matmul(Variable(sparse.to_dense()), dense_copy)
actual.backward(grad_output)
assert(torch.norm(dense.grad.data - dense_copy.grad.data) < 1e-5)
def _derivative_quadratic_form_factory(self, lhs, rhs):
def closure(left_factor, right_factor):
left_grad = left_factor.transpose(-1, -2).matmul(right_factor.matmul(rhs.transpose(-1, -2)))
right_grad = lhs.transpose(-1, -2).matmul(left_factor.transpose(-1, -2)).matmul(right_factor)
return left_grad, right_grad
return closure
def evaluate(self):
return torch.matmul(self.lhs, self.rhs)
def _matmul_closure_factory(self, tensor):
def closure(rhs_tensor):
return torch.matmul(tensor, rhs_tensor)
return closure
def diag(self):
"""
Gets the diagonal of the Kronecker Product matrix wrapped by this object.
"""
if len(self.J_lefts[0]) != len(self.J_rights[0]):
raise RuntimeError('diag not supported for non-square interpolated Toeplitz matrices.')
d, n_data, n_interp = self.J_lefts.size()
n_grid = len(self.columns[0])
left_interps_values = self.C_lefts.unsqueeze(3)
right_interps_values = self.C_rights.unsqueeze(2)
interps_values = torch.matmul(left_interps_values, right_interps_values)
left_interps_indices = self.J_lefts.unsqueeze(3).expand(d, n_data, n_interp, n_interp)
right_interps_indices = self.J_rights.unsqueeze(2).expand(d, n_data, n_interp, n_interp)
toeplitz_indices = (left_interps_indices - right_interps_indices).fmod(n_grid).abs().long()
toeplitz_vals = Variable(self.columns.data.new(d, n_data * n_interp * n_interp).zero_())
mask = self.columns.data.new(d, n_data * n_interp * n_interp).zero_()
for i in range(d):
mask[i] += torch.ones(n_data * n_interp * n_interp)
temp = self.columns.index_select(1, Variable(toeplitz_indices.view(d, -1)[i]))
toeplitz_vals += Variable(mask) * temp.view(toeplitz_indices.size())
mask[i] -= torch.ones(n_data * n_interp * n_interp)
diag = (Variable(interps_values) * toeplitz_vals).sum(3).sum(2)
diag = diag.prod(0)
if self.added_diag is not None:
diag += self.added_diag
return diag
def forward(self, input_d, input_e, mask_d=None, mask_e=None):
'''
Args:
input_d: Tensor
the decoder input tensor with shape = [batch, length_decoder, input_size]
input_e: Tensor
the child input tensor with shape = [batch, length_encoder, input_size]
mask_d: Tensor or None
the mask tensor for decoder with shape = [batch, length_decoder]
mask_e: Tensor or None
the mask tensor for encoder with shape = [batch, length_encoder]
Returns: Tensor
the energy tensor with shape = [batch, num_labels, length_decoder, length_encoder]
'''
assert input_d.size(0) == input_e.size(0), 'batch sizes of encoder and decoder are required to be equal.'
batch, length_decoder, _ = input_d.size()
_, length_encoder, _ = input_e.size()
# compute decoder part: [batch, length_decoder, input_size_decoder] * [input_size_decoder, hidden_size]
# the output shape is [batch, length_decoder, hidden_size]
# then --> [batch, 1, length_decoder, hidden_size]
out_d = torch.matmul(input_d, self.W_d).unsqueeze(1)
# compute encoder part: [batch, length_encoder, input_size_encoder] * [input_size_encoder, hidden_size]
# the output shape is [batch, length_encoder, hidden_size]
# then --> [batch, length_encoder, 1, hidden_size]
out_e = torch.matmul(input_e, self.W_e).unsqueeze(2)
# add them together [batch, length_encoder, length_decoder, hidden_size]
out = F.tanh(out_d + out_e + self.b)
# product with v
# [batch, length_encoder, length_decoder, hidden_size] * [hidden, num_label]
# [batch, length_encoder, length_decoder, num_labels]
# then --> [batch, num_labels, length_decoder, length_encoder]
return torch.matmul(out, self.v).transpose(1, 3)
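# A shape walk-through (illustrative tensors, not the module's real parameters) of the energy
# computation above, assuming hidden_size = 8 and num_labels = 3:
import torch

batch, length_decoder, length_encoder, input_size = 2, 5, 7, 16
W_d = torch.randn(input_size, 8)
W_e = torch.randn(input_size, 8)
v = torch.randn(8, 3)
b = torch.zeros(8)
out_d = torch.matmul(torch.randn(batch, length_decoder, input_size), W_d).unsqueeze(1)
out_e = torch.matmul(torch.randn(batch, length_encoder, input_size), W_e).unsqueeze(2)
energy = torch.matmul(torch.tanh(out_d + out_e + b), v).transpose(1, 3)
print(energy.shape)   # torch.Size([2, 3, 5, 7]) = [batch, num_labels, length_decoder, length_encoder]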