attention.py file source code

python

Project: keita    Author: iwasaki-kenta
import torch
import torch.nn.functional as F


def forward(self, *hidden_states):
    if len(hidden_states) == 1:
        # Single input: gate the hidden state by its own softmax-normalized projection.
        hidden_state = hidden_states[0]
        return F.softmax(F.tanh(self.projection(hidden_state))) * hidden_state
    elif len(hidden_states) == 2:
        left_hidden_state, right_hidden_state = hidden_states
        if self.mode == 0 or self.mode == 1:
            if self.mode == 0:
                # Mode 0: one shared projection scores both inputs.
                left_attention_weights = F.softmax(F.tanh(self.projection(left_hidden_state)))
                right_attention_weights = F.softmax(F.tanh(self.projection(right_hidden_state)))
            elif self.mode == 1:
                # Mode 1: separate projections for the left and right inputs.
                left_attention_weights = F.softmax(F.tanh(self.left_projection(left_hidden_state)))
                right_attention_weights = F.softmax(F.tanh(self.right_projection(right_hidden_state)))

            return left_attention_weights * left_hidden_state, right_attention_weights * right_hidden_state
        elif self.mode == 2:
            # Mode 2: concatenate both inputs and share a single attention weight vector.
            hidden_state = torch.cat([left_hidden_state, right_hidden_state], dim=1)
            attention_weights = F.softmax(F.tanh(self.projection(hidden_state)))

            return attention_weights * left_hidden_state, attention_weights * right_hidden_state
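
This forward() computes element-wise attention gates: each hidden state is multiplied by softmax(tanh(projection(state))). With two inputs, mode 0 shares one projection across both sides, mode 1 uses separate left/right projections, and mode 2 scores the concatenation of both states with a single shared weight vector. The enclosing class is not shown on this page, so the snippet below is only a minimal sketch of how the method might be wired up and called; the class name Attention, the constructor signature, and the nn.Linear layer sizes are assumptions, not the keita project's actual definitions.

import torch
from torch import nn


class Attention(nn.Module):  # assumed name; the real class in attention.py may differ
    def __init__(self, hidden_size, mode=0):
        super().__init__()
        self.mode = mode
        # Guess: mode 2 projects the concatenation of both states, so it needs
        # twice the input width; modes 0 and 1 project a single state.
        in_features = hidden_size * 2 if mode == 2 else hidden_size
        self.projection = nn.Linear(in_features, hidden_size)
        # Separate per-side projections, used only in mode 1.
        self.left_projection = nn.Linear(hidden_size, hidden_size)
        self.right_projection = nn.Linear(hidden_size, hidden_size)

    forward = forward  # reuse the forward() listed above


# Example: gate a pair of encoder states against each other (mode 1).
left = torch.randn(4, 128)
right = torch.randn(4, 128)
attention = Attention(hidden_size=128, mode=1)
weighted_left, weighted_right = attention(left, right)
print(weighted_left.shape, weighted_right.shape)  # torch.Size([4, 128]) twice

Binding forward = forward inside the class body is only a shortcut to reuse the function listed above without re-typing it; in the real file the method would simply be defined inside the class.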