def forward(self, inp, hidden):
    # inp: [len, bsz] token indices; hidden: initial BiLSTM hidden state
    outp = self.bilstm.forward(inp, hidden)[0]
    size = outp.size()  # [bsz, len, nhid*2]
    compressed_embeddings = outp.view(-1, size[2])  # [bsz*len, nhid*2]
    transformed_inp = torch.transpose(inp, 0, 1).contiguous()  # [bsz, len]
    transformed_inp = transformed_inp.view(size[0], 1, size[1])  # [bsz, 1, len]
    concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
    concatenated_inp = torch.cat(concatenated_inp, 1)  # [bsz, hop, len]

    hbar = self.tanh(self.ws1(self.drop(compressed_embeddings)))  # [bsz*len, attention-unit]
    alphas = self.ws2(hbar).view(size[0], size[1], -1)  # [bsz, len, hop]
    alphas = torch.transpose(alphas, 1, 2).contiguous()  # [bsz, hop, len]
    # [bsz, hop, len] + [bsz, hop, len]: push <pad> positions towards -inf so they
    # receive near-zero attention weight after the softmax
    penalized_alphas = alphas + (
        -10000 * (concatenated_inp == self.dictionary.word2idx['<pad>']).float())
    alphas = self.softmax(penalized_alphas.view(-1, size[1]))  # [bsz*hop, len]
    alphas = alphas.view(size[0], self.attention_hops, size[1])  # [bsz, hop, len]
    return torch.bmm(alphas, outp), alphas  # [bsz, hop, nhid*2], attention weights
Source file: models.py
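For reference, here is a minimal standalone sketch of the same masked, multi-hop self-attention computation run on random tensors, so the shape bookkeeping can be checked in isolation. All names and sizes below (bsz, seq_len, nhid, att_unit, hops, pad_idx) are illustrative assumptions, not values taken from the original model.

import torch
import torch.nn as nn

bsz, seq_len, nhid, att_unit, hops, pad_idx = 4, 10, 300, 350, 4, 0

tokens = torch.randint(1, 100, (seq_len, bsz))      # stand-in for inp: [len, bsz]
tokens[7:, :] = pad_idx                              # pretend the last positions are padding
outp = torch.randn(bsz, seq_len, 2 * nhid)           # stand-in for the BiLSTM output

ws1 = nn.Linear(2 * nhid, att_unit, bias=False)      # plays the role of self.ws1
ws2 = nn.Linear(att_unit, hops, bias=False)          # plays the role of self.ws2

hbar = torch.tanh(ws1(outp.reshape(-1, 2 * nhid)))            # [bsz*len, att_unit]
alphas = ws2(hbar).view(bsz, seq_len, hops).transpose(1, 2)   # [bsz, hop, len]
mask = (tokens.t().unsqueeze(1) == pad_idx).float()           # [bsz, 1, len]
alphas = torch.softmax(alphas - 10000 * mask, dim=-1)         # <pad> positions get ~0 weight
sentence_embedding = torch.bmm(alphas, outp)                  # [bsz, hop, 2*nhid]
print(sentence_embedding.shape)                               # torch.Size([4, 4, 600])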