def forward(self, inp, hidden):
    outp = self.bilstm.forward(inp, hidden)[0]
    size = outp.size()  # [bsz, len, nhid*2]
    compressed_embeddings = outp.view(-1, size[2])  # [bsz*len, nhid*2]
    transformed_inp = torch.transpose(inp, 0, 1).contiguous()  # [bsz, len]
    transformed_inp = transformed_inp.view(size[0], 1, size[1])  # [bsz, 1, len]
    concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
    concatenated_inp = torch.cat(concatenated_inp, 1)  # [bsz, hop, len]
    hbar = self.tanh(self.ws1(self.drop(compressed_embeddings)))  # [bsz*len, attention-unit]
    alphas = self.ws2(hbar).view(size[0], size[1], -1)  # [bsz, len, hop]
    alphas = torch.transpose(alphas, 1, 2).contiguous()  # [bsz, hop, len]
    penalized_alphas = alphas + (
        -10000 * (concatenated_inp == self.dictionary.word2idx['<pad>']).float())
        # [bsz, hop, len] + [bsz, hop, len]
    alphas = self.softmax(penalized_alphas.view(-1, size[1]))  # [bsz*hop, len]
    alphas = alphas.view(size[0], self.attention_hops, size[1])  # [bsz, hop, len]
    return torch.bmm(alphas, outp), alphas
Python torch.bmm() usage examples
Source file: models.py
Project: Structured-Self-Attentive-Sentence-Embedding
Author: ExplorerFreda
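A minimal shape check of the final torch.bmm in the self-attentive pooling above; the sizes below are illustrative stand-ins, not the project's actual configuration.

import torch

bsz, seq_len, nhid, hops = 4, 7, 16, 3                            # illustrative sizes
outp = torch.randn(bsz, seq_len, 2 * nhid)                        # BiLSTM outputs
alphas = torch.softmax(torch.randn(bsz, hops, seq_len), dim=-1)   # attention weights per hop
sentence_embedding = torch.bmm(alphas, outp)                      # [bsz, hops, 2*nhid]
print(sentence_embedding.shape)                                   # torch.Size([4, 3, 32])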
def backward(self, grad_output):
    batch1, batch2 = self.saved_tensors
    grad_add_matrix = grad_batch1 = grad_batch2 = None
    if self.needs_input_grad[0]:
        grad_add_matrix = grad_output
        if self.alpha != 1:
            grad_add_matrix = grad_add_matrix.mul(self.alpha)
    if any(self.needs_input_grad[1:]):
        batch_grad_output = (grad_output
                             .unsqueeze(0)
                             .expand(batch1.size(0), batch1.size(1), batch2.size(2)))
    if self.needs_input_grad[1]:
        grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
        if self.beta != 1:
            grad_batch1 *= self.beta
    if self.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
        if self.beta != 1:
            grad_batch2 *= self.beta
    return grad_add_matrix, grad_batch1, grad_batch2
def backward(self, grad_output):
    batch1, batch2 = self.saved_tensors
    grad_add_batch = grad_batch1 = grad_batch2 = None
    if self.needs_input_grad[0]:
        grad_add_batch = grad_output
        if self.alpha != 1:
            grad_add_batch = grad_add_batch.mul(self.alpha)
    if self.needs_input_grad[1]:
        grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
        if self.beta != 1:
            grad_batch1 *= self.beta
    if self.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
        if self.beta != 1:
            grad_batch2 *= self.beta
    return grad_add_batch, grad_batch1, grad_batch2
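The gradient rules hand-coded in the two backward methods above can be sanity-checked against autograd on a plain torch.bmm; the shapes here are arbitrary.

import torch

batch1 = torch.randn(4, 2, 3, requires_grad=True)
batch2 = torch.randn(4, 3, 5, requires_grad=True)
out = torch.bmm(batch1, batch2)
grad_output = torch.randn_like(out)
out.backward(grad_output)

# same formulas as the hand-written backward implementations above
assert torch.allclose(batch1.grad, torch.bmm(grad_output, batch2.transpose(1, 2)), atol=1e-5)
assert torch.allclose(batch2.grad, torch.bmm(batch1.transpose(1, 2), grad_output), atol=1e-5)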
def updateOutput(self, input):
    assert len(input) == 2
    a, b = input
    assert a.ndimension() == 2 or a.ndimension() == 3
    assert a.dim() == b.dim()
    if a.ndimension() == 2:
        if self.transA:
            a = a.t()
        if self.transB:
            b = b.t()
        self.output.resize_(a.size(0), b.size(1))
        torch.mm(a, b, out=self.output)
    else:
        if self.transA:
            a = a.transpose(1, 2)
        if self.transB:
            b = b.transpose(1, 2)
        self.output.resize_(a.size(0), a.size(1), b.size(2))
        torch.bmm(a, b, out=self.output)
    return self.output
def updateOutput(self, input):
    M, v = input
    assert M.ndimension() == 2 or M.ndimension() == 3
    if M.ndimension() == 2:
        assert v.ndimension() == 1
        if self.trans:
            M = M.transpose(0, 1)
        self.output.resize_(M.size(0))
        torch.mv(M, v, out=self.output)
    else:
        assert v.ndimension() == 2
        if self.trans:
            M = M.transpose(1, 2)
        self.output.resize_(M.size(0), M.size(1), 1)
        torch.bmm(M, v.view(v.size(0), v.size(1), 1), out=self.output).resize_(M.size(0), M.size(1))
    return self.output
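For reference, the 3-D branch of this legacy MV module amounts to a single batched matrix-vector product in current PyTorch; a minimal sketch with arbitrary sizes:

import torch

M = torch.randn(8, 5, 6)                         # batch of matrices
v = torch.randn(8, 6)                            # batch of vectors
out = torch.bmm(M, v.unsqueeze(2)).squeeze(2)    # [8, 5]
# equivalently: out = (M @ v.unsqueeze(2)).squeeze(2)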
def forward(self, q, k, v, attn_mask):
    d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
    residual = q
    bsz, len_q, d_model = q.size()
    len_k, len_v = k.size(1), v.size(1)

    def reshape(x):
        """[bsz, len, d_*] -> [n_head x (bsz*len) x d_*]"""
        return x.repeat(n_head, 1, 1).view(n_head, -1, d_model)

    q_s, k_s, v_s = map(reshape, [q, k, v])
    q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)
    k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)
    v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)
    outputs = self.attention(q_s, k_s, v_s, attn_mask.repeat(n_head, 1, 1))
    outputs = torch.cat(torch.split(outputs, bsz, dim=0), dim=-1).view(-1, n_head * d_v)
    outputs = F.dropout(self.w_o(outputs), p=self.dropout).view(bsz, len_q, -1)
    return self.lm(outputs + residual)
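Each per-head projection above is itself a bmm: an [n_head, d_model, d_k] weight tensor multiplies the replicated inputs in one call. A shape-only sketch with made-up sizes; w_qs here is a random stand-in for the module's parameter.

import torch

n_head, bsz, len_q, d_model, d_k = 8, 2, 5, 64, 16
q = torch.randn(bsz, len_q, d_model)
w_qs = torch.randn(n_head, d_model, d_k)

q_rep = q.repeat(n_head, 1, 1).view(n_head, -1, d_model)   # n_head x (bsz*len_q) x d_model
q_s = torch.bmm(q_rep, w_qs).view(-1, len_q, d_k)          # (n_head*bsz) x len_q x d_k
print(q_s.shape)                                           # torch.Size([16, 5, 16])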
def forward(self, x, target_embedding, encoder_out):
    residual = x
    # attention
    x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5)
    x = self.bmm(x, encoder_out[0])
    # softmax over last dim
    sz = x.size()
    x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=-1)
    x = x.view(sz)
    attn_scores = x
    x = self.bmm(x, encoder_out[1])
    # scale attention output
    s = encoder_out[1].size(1)
    x = x * (s * math.sqrt(1.0 / s))
    # project back
    x = (self.out_projection(x) + residual) * math.sqrt(0.5)
    return x, attn_scores
Source file: Modules.py
Project: attention-is-all-you-need-pytorch
Author: jadore801120
def forward(self, q, k, v, attn_mask=None):
    attn = torch.bmm(q, k.transpose(1, 2)) / self.temper
    if attn_mask is not None:
        assert attn_mask.size() == attn.size(), \
            'Attention mask shape {} mismatch ' \
            'with Attention logit tensor shape ' \
            '{}.'.format(attn_mask.size(), attn.size())
        attn.data.masked_fill_(attn_mask, -float('inf'))
    attn = self.softmax(attn)
    attn = self.dropout(attn)
    output = torch.bmm(attn, v)
    return output, attn
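A standalone sketch of the same scaled dot-product attention built from two torch.bmm calls; the sizes and the sqrt(d_k) temperature are illustrative assumptions, not values taken from the module.

import torch
import torch.nn.functional as F

bsz, len_q, len_k, d_k, d_v = 2, 5, 7, 16, 16
q = torch.randn(bsz, len_q, d_k)
k = torch.randn(bsz, len_k, d_k)
v = torch.randn(bsz, len_k, d_v)

attn = torch.bmm(q, k.transpose(1, 2)) / (d_k ** 0.5)   # [bsz, len_q, len_k]
attn = F.softmax(attn, dim=-1)
output = torch.bmm(attn, v)                              # [bsz, len_q, d_v]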
def forward(self, inputs, context):
    """
    inputs: batch x dim
    context: batch x sourceL x dim
    """
    targetT = self.linear_in(inputs).unsqueeze(2)  # batch x dim x 1
    # Get attention
    attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
    if self.mask is not None:
        attn.data.masked_fill_(self.mask, -_INF)
    attn = self.sm(attn)
    attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL
    weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
    contextCombined = torch.cat((weightedContext, inputs), 1)
    contextOutput = self.tanh(self.linear_out(contextCombined))
    return contextOutput, attn
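The same Luong-style global attention reduced to a functional sketch: the learned linear_in/linear_out projections are dropped and the query and context dimensions are assumed equal, so only the two bmm calls remain.

import torch
import torch.nn.functional as F

batch, sourceL, dim = 3, 6, 32
inputs = torch.randn(batch, dim)
context = torch.randn(batch, sourceL, dim)

scores = torch.bmm(context, inputs.unsqueeze(2)).squeeze(2)   # batch x sourceL
attn = F.softmax(scores, dim=-1)
weighted = torch.bmm(attn.unsqueeze(1), context).squeeze(1)   # batch x dim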
def forward(self, input, context):
    """
    input: batch x dim
    context: batch x sourceL x dim
    """
    targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1
    # Get attention
    attn = torch.bmm(context, targetT).squeeze(2)  # batch x sourceL
    if self.mask is not None:
        attn.data.masked_fill_(self.mask, -float('inf'))
    attn = self.sm(attn)
    attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL
    weightedContext = torch.bmm(attn3, context).squeeze(1)  # batch x dim
    contextCombined = torch.cat((weightedContext, input), 1)
    contextOutput = self.tanh(self.linear_out(contextCombined))
    return contextOutput, attn
def forward(self, x):
    batchsize = x.size()[0]
    trans = self.stn(x)  # regressing the transforming parameters using STN
    x = x.transpose(2, 1)  # bz x 2048 x 3
    x = torch.bmm(x, trans)  # (bz x 2048 x 3) x (bz x 3 x 3)
    x = x.transpose(2, 1)  # bz x 3 x 2048
    x = F.relu(self.bn1(self.conv1(x)))
    pointfeat = x  # bz x 64 x 2048
    x = F.relu(self.bn2(self.conv2(x)))  # bz x 128 x 2048
    x = self.bn3(self.conv3(x))  # bz x 1024 x 2048
    x = self.mp1(x)
    x = x.view(-1, 1024)  # bz x 1024
    if self.global_feat:  # using global feats for classification
        return x, trans
    else:
        x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([x, pointfeat], 1), trans
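The bmm in this PointNet block applies one learned 3x3 transform per sample to its point cloud. A shape-only sketch, with an identity matrix standing in for the STN output:

import torch

bsz, num_points = 2, 2048
points = torch.randn(bsz, 3, num_points)            # bz x 3 x N, as in the forward above
trans = torch.eye(3).expand(bsz, 3, 3)              # stand-in for self.stn(x)

aligned = torch.bmm(points.transpose(2, 1), trans).transpose(2, 1)   # still bz x 3 x N
print(aligned.shape)                                # torch.Size([2, 3, 2048])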
def backward(ctx, grad_output):
    batch1, batch2 = ctx.saved_variables
    grad_add_matrix = grad_batch1 = grad_batch2 = None
    if ctx.needs_input_grad[0]:
        grad_add_matrix = grad_output
        if ctx.alpha != 1:
            grad_add_matrix = grad_add_matrix.mul(ctx.alpha)
    if any(ctx.needs_input_grad[1:]):
        batch_grad_output = (grad_output
                             .unsqueeze(0)
                             .expand(batch1.size(0), batch1.size(1), batch2.size(2)))
    if ctx.needs_input_grad[1]:
        grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
        if ctx.beta != 1:
            grad_batch1 *= ctx.beta
    if ctx.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
        if ctx.beta != 1:
            grad_batch2 *= ctx.beta
    return grad_add_matrix, grad_batch1, grad_batch2, None, None, None
def backward(ctx, grad_output):
    batch1, batch2 = ctx.saved_variables
    grad_add_batch = grad_batch1 = grad_batch2 = None
    if ctx.needs_input_grad[0]:
        grad_add_batch = grad_output
        if ctx.alpha != 1:
            grad_add_batch = grad_add_batch.mul(ctx.alpha)
    if ctx.needs_input_grad[1]:
        grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
        if ctx.beta != 1:
            grad_batch1 *= ctx.beta
    if ctx.needs_input_grad[2]:
        grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
        if ctx.beta != 1:
            grad_batch2 *= ctx.beta
    return grad_add_batch, grad_batch1, grad_batch2, None, None, None
def updateOutput(self, input):
    assert len(input) == 2
    a, b = input
    assert a.ndimension() == 2 or a.ndimension() == 3
    assert a.dim() == b.dim()
    if a.ndimension() == 2:
        if self.transA:
            a = a.t()
        if self.transB:
            b = b.t()
        self.output.resize_(a.size(0), b.size(1))
        torch.mm(a, b, out=self.output)
    else:
        if self.transA:
            a = a.transpose(1, 2)
        if self.transB:
            b = b.transpose(1, 2)
        self.output.resize_(a.size(0), a.size(1), b.size(2))
        torch.bmm(a, b, out=self.output)
    return self.output
def updateOutput(self, input):
    M, v = input
    assert M.ndimension() == 2 or M.ndimension() == 3
    if M.ndimension() == 2:
        assert v.ndimension() == 1
        if self.trans:
            M = M.transpose(0, 1)
        self.output.resize_(M.size(0))
        torch.mv(M, v, out=self.output)
    else:
        assert v.ndimension() == 2
        if self.trans:
            M = M.transpose(1, 2)
        self.output.resize_(M.size(0), M.size(1), 1)
        torch.bmm(M, v.view(v.size(0), v.size(1), 1), out=self.output).resize_(M.size(0), M.size(1))
    return self.output
def _test_btrisolve(self, cast):
    a = torch.FloatTensor((((1.3722, -0.9020),
                            (1.8849, 1.9169)),
                           ((0.7187, -1.1695),
                            (-0.0139, 1.3572)),
                           ((-1.6181, 0.7148),
                            (1.3728, 0.1319))))
    b = torch.FloatTensor(((4.02, 6.19),
                           (-1.56, 4.00),
                           (9.81, -4.09)))
    a, b = cast(a), cast(b)
    info = cast(torch.IntTensor())
    LU_data, pivots = a.btrifact(info=info)
    self.assertEqual(info.abs().sum(), 0)
    x = torch.btrisolve(b, LU_data, pivots)
    b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
    self.assertEqual(b_, b)
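btrifact/btrisolve have since been deprecated; a hedged modern equivalent using torch.linalg.solve keeps the same torch.bmm round-trip check (A is built to be well conditioned so the tolerance holds):

import torch

A = torch.eye(2) + 0.1 * torch.randn(3, 2, 2)    # batch of well-conditioned 2x2 systems
b = torch.randn(3, 2)
x = torch.linalg.solve(A, b.unsqueeze(2))        # batched solve, [3, 2, 1]
assert torch.allclose(torch.bmm(A, x).squeeze(2), b, atol=1e-5)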
def forward(self, input1):
    self.input1 = input1
    output = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
    self.batchgrid = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
    for i in range(input1.size(0)):
        self.batchgrid[i] = self.grid
    if input1.is_cuda:
        self.batchgrid = self.batchgrid.cuda()
        output = output.cuda()
    batchgrid_temp = self.batchgrid.view(-1, self.height * self.width, 3)
    batchgrid_temp = batchgrid_temp.contiguous()
    input_temp = torch.transpose(input1, 1, 2)
    input_temp = input_temp.contiguous()
    output_temp = torch.bmm(batchgrid_temp, input_temp)
    output = output_temp.view(-1, self.height, self.width, 2)
    output = output.contiguous()
    return output
def forward(self, input, context):
    """Propagate input through the network.
    input: batch x dim
    context: batch x sourceL x dim
    """
    target = self.linear_in(input).unsqueeze(2)  # batch x dim x 1
    # Get attention
    attn = torch.bmm(context, target).squeeze(2)  # batch x sourceL
    attn = self.sm(attn)
    attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x sourceL
    weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
    h_tilde = torch.cat((weighted_context, input), 1)
    h_tilde = self.tanh(self.linear_out(h_tilde))
    return h_tilde, attn
def updateOutput(self, input):
    assert len(input) == 2
    a, b = input
    assert a.ndimension() == 2 or a.ndimension() == 3
    assert a.dim() == b.dim()
    if a.ndimension() == 2:
        if self.transA:
            a = a.t()
        if self.transB:
            b = b.t()
        self.output.resize_(a.size(0), b.size(1))
        torch.mm(a, b, out=self.output)
    else:
        if self.transA:
            a = a.transpose(1, 2)
        if self.transB:
            b = b.transpose(1, 2)
        self.output.resize_(a.size(0), a.size(1), b.size(2))
        torch.bmm(a, b, out=self.output)
    return self.output
def forward(self, x):
    batchsize = x.size()[0]
    trans = self.stn(x)
    x = x.transpose(2, 1)
    x = torch.bmm(x, trans)
    x = x.transpose(2, 1)
    x = F.relu(self.bn1(self.conv1(x)))
    pointfeat = x
    x = F.relu(self.bn2(self.conv2(x)))
    x = self.bn3(self.conv3(x))
    x = self.mp1(x)
    x = x.view(-1, 1024)
    if self.global_feat:
        return x, trans
    else:
        x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([x, pointfeat], 1), trans
def calc_score(self, att_query, att_keys):
    """
    att_query is: b x t_q x n
    att_keys is b x t_k x n
    return b x t_q x t_k scores
    """
    b, t_k, n = list(att_keys.size())
    t_q = att_query.size(1)
    if self.mode == 'bahdanau':
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys
        sum_qk = sum_qk.view(b * t_k * t_q, n)
        out = self.linear_att(F.tanh(sum_qk)).view(b, t_q, t_k)
    elif self.mode == 'dot_prod':
        out = torch.bmm(att_query, att_keys.transpose(1, 2))
        if self.normalize:
            out.div_(n ** 0.5)
    return out
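A quick shape sketch of the 'dot_prod' branch: one bmm against the transposed keys produces all query-key scores at once (sizes are arbitrary).

import torch

b, t_q, t_k, n = 2, 4, 6, 32
att_query = torch.randn(b, t_q, n)
att_keys = torch.randn(b, t_k, n)
scores = torch.bmm(att_query, att_keys.transpose(1, 2)) / (n ** 0.5)   # b x t_q x t_k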
def forward(self, v, z):
    '''
    :param v: batch_size (B) x latent_size (L)
    :param z: batch_size (B) x latent_size (L)
    :return: z_new = z - 2 * v v^T z / ||v||^2
    '''
    # v * v_T
    vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))  # batch_dot( B x L x 1 * B x 1 x L ) = B x L x L
    # v * v_T * z
    vvTz = torch.bmm(vvT, z.unsqueeze(2)).squeeze(2)  # batch_dot( B x L x L * B x L x 1 ).squeeze(2) = B x L
    # calculate norm ||v||^2
    norm_sq = torch.sum(v * v, 1, keepdim=True)  # squared 2-norm of each row: B x 1
    norm_sq = norm_sq.expand(norm_sq.size(0), v.size(1))  # expand sizes: B x L
    # calculate new z
    z_new = z - 2 * vvTz / norm_sq  # z - 2 * v * v_T * z / ||v||^2
    return z_new
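Since z_new = (I - 2 v v^T / ||v||^2) z is a Householder reflection, it must preserve the norm of z; a small numeric check of the formula used above:

import torch

B, L = 4, 6
v = torch.randn(B, L)
z = torch.randn(B, L)

vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))                    # B x L x L
z_new = z - 2 * torch.bmm(vvT, z.unsqueeze(2)).squeeze(2) / (v * v).sum(1, keepdim=True)
# a Householder reflection is orthogonal, so row norms are unchanged
assert torch.allclose(z_new.norm(dim=1), z.norm(dim=1), atol=1e-5)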
def forward(self, L, z):
    '''
    :param L: batch_size (B) x latent_size^2 (L^2)
    :param z: batch_size (B) x latent_size (L)
    :return: z_new = L * z
    '''
    # L -> tril(L)
    L_matrix = L.view(-1, self.args.z1_size, self.args.z1_size)  # resize to get B x L x L
    LTmask = torch.tril(torch.ones(self.args.z1_size, self.args.z1_size), diagonal=-1)  # lower-triangular mask matrix (1s in the strictly lower triangular part)
    I = Variable(torch.eye(self.args.z1_size, self.args.z1_size).expand(L_matrix.size(0), self.args.z1_size, self.args.z1_size))
    if self.args.cuda:
        LTmask = LTmask.cuda()
        I = I.cuda()
    LTmask = Variable(LTmask)
    LTmask = LTmask.unsqueeze(0).expand(L_matrix.size(0), self.args.z1_size, self.args.z1_size)  # 1 x L x L -> B x L x L
    LT = torch.mul(L_matrix, LTmask) + I  # a batch of lower-triangular matrices with ones on the diagonal
    # z_new = L * z
    z_new = torch.bmm(LT, z.unsqueeze(2)).squeeze(2)  # B x L x L * B x L x 1 -> B x L
    return z_new
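A compact check of the masking trick above, with arbitrary sizes: keeping only the strictly lower-triangular part of L and adding the identity yields a unit lower-triangular matrix per sample before the bmm with z.

import torch

B, D = 3, 5
L_flat = torch.randn(B, D * D)
L_matrix = L_flat.view(-1, D, D)
mask = torch.tril(torch.ones(D, D), diagonal=-1)     # strictly lower-triangular mask
LT = L_matrix * mask + torch.eye(D)                  # unit lower-triangular, broadcast over the batch
z = torch.randn(B, D)
z_new = torch.bmm(LT, z.unsqueeze(2)).squeeze(2)     # B x D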
def forward(ctx, theta, size):
    assert type(size) == torch.Size
    N, C, H, W = size
    ctx.size = size
    if theta.is_cuda:
        ctx.is_cuda = True
        AffineGridGenerator._enforce_cudnn(theta)
        grid = theta.new(N, H, W, 2)
        theta = theta.contiguous()
        torch._C._cudnn_affine_grid_generator_forward(theta, grid, N, C, H, W)
    else:
        ctx.is_cuda = False
        base_grid = theta.new(N, H, W, 3)
        linear_points = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
        base_grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(base_grid[:, :, :, 0])
        linear_points = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
        base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(base_grid[:, :, :, 1])
        base_grid[:, :, :, 2] = 1
        ctx.base_grid = base_grid
        grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
        grid = grid.view(N, H, W, 2)
    return grid
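The CPU branch above (a base grid of (x, y, 1) coordinates multiplied by theta^T via bmm) is what torch.nn.functional.affine_grid now provides directly; align_corners=True reproduces the -1..1 endpoint convention of this legacy code.

import torch
import torch.nn.functional as F

theta = torch.randn(2, 2, 3)                                  # batch of 2x3 affine matrices
grid = F.affine_grid(theta, size=(2, 3, 8, 8), align_corners=True)
print(grid.shape)                                             # torch.Size([2, 8, 8, 2])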