Source code of S2S_att.py

Language: Python

Project: seq2seq_temporal_attention    Author: aistairc
# Assumed imports: batch_matmul, exp and reshape are Chainer ops used below;
# the original S2S_att.py presumably pulls them in from chainer.functions.
from chainer.functions import batch_matmul, exp, reshape

def __call__(self, a_list, state, batch_size, xp):
    # Score each encoder annotation `a` against the decoder state h2 with a
    # per-batch dot product, clip the logits for numerical stability, and
    # accumulate the softmax denominator.
    e_list = []
    sum_e = xp.zeros((batch_size, 1), dtype=xp.float32)
    for a in a_list:
        w = reshape(batch_matmul(state['h2'], a, transa=True), (batch_size, 1))
        w.data = xp.clip(w.data, -40, 40)
        e = exp(w)
        e_list.append(e)
        sum_e = sum_e + e

    # Context vector: sum of the annotations weighted by the normalized
    # attention weights e / sum_e. The unnormalized e_list and sum_e are
    # returned alongside the context vector.
    context = xp.zeros((batch_size, self.hidden_size), dtype=xp.float32)
    for a, e in zip(a_list, e_list):
        e /= sum_e
        context = context + reshape(batch_matmul(a, e), (batch_size, self.hidden_size))
    return context, e_list, sum_e
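
For reference, here is a minimal NumPy sketch of the same computation, detached from Chainer: clipped dot-product scores between the decoder state h2 and each annotation, a softmax over those scores, and a weighted sum as the context vector. The function name attention_numpy, the shapes, and the random example data are illustrative assumptions, not part of the project.

import numpy as np

def attention_numpy(a_list, h2, hidden_size):
    # a_list: list of encoder annotations, each of shape (batch_size, hidden_size)
    # h2:     decoder state, shape (batch_size, hidden_size)
    batch_size = h2.shape[0]
    # Per-batch dot-product scores, clipped to [-40, 40] before exponentiation.
    scores = [np.clip(np.sum(h2 * a, axis=1, keepdims=True), -40, 40) for a in a_list]
    e_list = [np.exp(w) for w in scores]          # each of shape (batch_size, 1)
    sum_e = np.sum(e_list, axis=0)                # softmax denominator, (batch_size, 1)
    context = np.zeros((batch_size, hidden_size), dtype=np.float32)
    for a, e in zip(a_list, e_list):
        context += a * (e / sum_e)                # weighted sum of annotations
    return context, e_list, sum_e

# Example with random data: 3 annotation vectors, batch of 2, hidden size 4.
rng = np.random.default_rng(0)
a_list = [rng.standard_normal((2, 4)).astype(np.float32) for _ in range(3)]
h2 = rng.standard_normal((2, 4)).astype(np.float32)
context, e_list, sum_e = attention_numpy(a_list, h2, hidden_size=4)
print(context.shape)  # (2, 4)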