attention.py source code

python

Project: keita    Author: iwasaki-kenta
import torch
import torch.nn.functional as F


# forward() of an attention module; it relies on self.projection and
# self.encoder_dim being defined in the module's __init__.
def forward(self, last_state, states):
    # Accept a single 2-D (length, dim) input by adding a leading dimension.
    if len(states.size()) == 2: states = states.unsqueeze(0)

    sequence_length, batch_size, state_dim = states.size()

    # Project the last state into the encoder dimension, then tile it so it
    # lines up with every encoder state.
    transformed_last_state = last_state @ self.projection
    transformed_last_state = transformed_last_state.expand(sequence_length, batch_size, self.encoder_dim)
    transformed_last_state = transformed_last_state.transpose(0, 1).contiguous()
    transformed_last_state = transformed_last_state.view(batch_size, -1)

    states = states.transpose(0, 1).contiguous()
    states = states.view(batch_size, -1)

    # Dot-product energies: one score per row of the flattened views.
    energies = transformed_last_state * states
    energies = energies.sum(dim=1)

    if self.encoder_dim is not None:
        # The first position keeps its raw exponential; the rest are softmax-normalized.
        attention_weights = torch.cat([torch.exp(energies[:1]), F.softmax(energies[1:], dim=0)], dim=0)
    else:
        attention_weights = F.softmax(energies, dim=0)

    return attention_weights
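
A minimal usage sketch, assuming the snippet above is the forward() of an attention module whose __init__ defines self.projection (a matrix projecting the last state into the encoder dimension) and self.encoder_dim. The stand-in object, the 128/64 dimensions, and the sequence length of 10 below are hypothetical, chosen only to illustrate the expected shapes:

from types import SimpleNamespace
import torch

# Stand-in exposing only the attributes that forward() reads (hypothetical setup).
attn = SimpleNamespace(
    encoder_dim=64,
    projection=torch.randn(128, 64),  # assumed: 128-dim last state -> 64-dim encoder space
)

last_state = torch.randn(128)   # e.g. a decoder's final hidden state
states = torch.randn(10, 64)    # 10 encoder states for a single sequence (2-D input)

weights = forward(attn, last_state, states)  # call the forward() above as a plain function
print(weights.shape)            # torch.Size([10]): one weight per encoder position

Because a 2-D input is given a leading length-1 dimension, its 10 positions end up on the "batch" axis inside forward(), so the returned tensor holds one weight per encoder position.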