attention.py source code

python

Project: keita · Author: iwasaki-kenta
import torch
import torch.nn.functional as F


def forward(self, last_state, states, mask=None):
    # states: (sequence_length, batch_size, hidden_dim); last_state: (batch_size, hidden_dim).
    sequence_length, batch_size, hidden_dim = states.size()

    # Broadcast the query (last_state) across every time step of the sequence.
    last_state = last_state.unsqueeze(0).expand(sequence_length, batch_size, last_state.size(1))

    if self.mode == "dot":
        # Dot-product score: element-wise product summed over the hidden dimension.
        energies = (last_state * states).sum(dim=2)
    elif self.mode == "general":
        # General (bilinear) score: project the states before taking the dot product.
        expanded_projection = self.projection.expand(sequence_length, *self.projection.size())
        energies = (last_state * states.bmm(expanded_projection)).sum(dim=2)
    elif self.mode == "concat":
        # Concat (additive) score: concatenate query and states, reduce, then project to a scalar.
        expanded_reduction = self.reduction.expand(sequence_length, *self.reduction.size())
        expanded_projection = self.projection.expand(sequence_length, *self.projection.size())
        energies = torch.tanh(torch.cat([last_state, states], dim=2).bmm(expanded_reduction))
        energies = energies.bmm(expanded_projection).squeeze(dim=2)

    if mask is not None:
        # Drive masked-out positions towards zero probability after the softmax.
        energies = energies + ((mask == 0).float() * -10000)

    # Normalize over the sequence (time) dimension so the weights sum to 1 per batch element.
    attention_weights = F.softmax(energies, dim=0)

    return attention_weights
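
Only the forward method appears in this file excerpt; the module that owns it (the attention mode string and the projection / reduction parameters) is not shown. The sketch below is one plausible way to wrap and call it, assuming parameter shapes that make the bmm calls above type-check. The class name Attention, the nn.Parameter initialization, and the tensor sizes in the usage lines are all assumptions, not the keita project's actual code.

import torch
from torch import nn


class Attention(nn.Module):
    """Hypothetical container for the forward() defined above; shapes are assumed."""

    def __init__(self, hidden_dim, mode="dot"):
        super().__init__()
        self.mode = mode
        if mode == "general":
            # Bilinear weight applied to the encoder states: (hidden_dim, hidden_dim).
            self.projection = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        elif mode == "concat":
            # Reduce the concatenated (query, state) pair, then project each score to a scalar.
            self.reduction = nn.Parameter(torch.randn(2 * hidden_dim, hidden_dim))
            self.projection = nn.Parameter(torch.randn(hidden_dim, 1))

    forward = forward  # reuse the module-level forward() shown above


# Hypothetical usage: a 5-step sequence, batch of 3, hidden size 8.
seq_len, batch, hidden = 5, 3, 8
attention = Attention(hidden, mode="dot")
states = torch.randn(seq_len, batch, hidden)   # encoder outputs
last_state = torch.randn(batch, hidden)        # decoder query
weights = attention(last_state, states)        # (seq_len, batch); each column sums to 1

If a mask is passed, it is expected to be broadcastable to the (sequence_length, batch_size) energies, with zeros at padding positions; the large negative offset added before the softmax pushes those positions' weights towards zero.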