def _attention_forward(self, Y, mask_Y, h, r_tm1=None, index=None):
    '''
    Computes the attention weights over Y using h (and r_tm1 if given).
    Returns an attention-weighted representation of Y, and the alphas.
    inputs:
        Y : T x batch x n_dim
        mask_Y : T x batch
        h : batch x n_dim
        r_tm1 : batch x n_dim
        index : int : the timestep
    params:
        W_y : n_dim x n_dim
        W_h : n_dim x n_dim
        W_r : n_dim x n_dim
        W_alpha : n_dim x 1
    outputs:
        r : batch x n_dim
        alpha : batch x T
    '''
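    # Summary of the computation below (comments added for readability; the
    # bracketed terms apply only when r_tm1 is given):
    #   M     = tanh(Y W_y + (h W_h [+ r_{t-1} W_r]))   broadcast over the T steps
    #   alpha = softmax(M W_alpha + mask penalty)       normalised over the T steps
    #   r     = alpha Y [+ tanh(r_{t-1} W_t)]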
    Y = Y.transpose(1, 0)  # batch x T x n_dim
    mask_Y = mask_Y.transpose(1, 0)  # batch x T
    Wy = torch.bmm(Y, self.W_y.unsqueeze(0).expand(Y.size(0), *self.W_y.size()))  # batch x T x n_dim
    Wh = torch.mm(h, self.W_h)  # batch x n_dim
    if r_tm1 is not None:
        W_r_tm1 = self.batch_norm_r_r(torch.mm(r_tm1, self.W_r), index) if hasattr(self, 'batch_norm_r_r') else torch.mm(r_tm1, self.W_r)
        Wh = self.batch_norm_h_r(Wh, index) if hasattr(self, 'batch_norm_h_r') else Wh
        Wh += W_r_tm1
    M = torch.tanh(Wy + Wh.unsqueeze(1).expand(Wh.size(0), Y.size(1), Wh.size(1)))  # batch x T x n_dim
    alpha = torch.bmm(M, self.W_alpha.unsqueeze(0).expand(Y.size(0), *self.W_alpha.size())).squeeze(-1)  # batch x T
    alpha = alpha + (-1000.0 * (1. - mask_Y))  # large negative bias so no probability mass falls on padding tokens
    alpha = F.softmax(alpha, dim=-1)  # normalise over the T timesteps
    if r_tm1 is not None:
        r = torch.bmm(alpha.unsqueeze(1), Y).squeeze(1) + torch.tanh(torch.mm(r_tm1, self.W_t))  # batch x n_dim
    else:
        r = torch.bmm(alpha.unsqueeze(1), Y).squeeze(1)  # batch x n_dim
    return r, alpha
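
Below is a minimal sketch of how this method could be exercised in isolation, assuming the file imports torch and torch.nn.functional as F at the top. The SimpleNamespace stand-in for self, the tensor sizes, and the random weights are illustrative only and are not part of rte_model.py; in the real model the W_* matrices are parameters of the owning module and the method is called on that module.

from types import SimpleNamespace

import torch
import torch.nn.functional as F  # needed by _attention_forward

T, batch, n_dim = 7, 4, 16
params = SimpleNamespace(          # hypothetical stand-in for the model's parameters
    W_y=torch.randn(n_dim, n_dim) * 0.1,
    W_h=torch.randn(n_dim, n_dim) * 0.1,
    W_r=torch.randn(n_dim, n_dim) * 0.1,
    W_t=torch.randn(n_dim, n_dim) * 0.1,
    W_alpha=torch.randn(n_dim, 1) * 0.1,
)

Y = torch.randn(T, batch, n_dim)   # encoder states to attend over
mask_Y = torch.ones(T, batch)      # 1 = real token, 0 = padding
mask_Y[5:, 0] = 0.                 # e.g. first sequence has only 5 real tokens
h = torch.randn(batch, n_dim)      # current hidden state
r_tm1 = torch.randn(batch, n_dim)  # previous attention output

r, alpha = _attention_forward(params, Y, mask_Y, h, r_tm1=r_tm1, index=0)
print(r.shape)       # torch.Size([4, 16])
print(alpha.shape)   # torch.Size([4, 7])
print(alpha[0, 5:])  # ~0: masked positions receive (almost) no attention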
Source: rte_model.py (Python)