memory_module.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:chainer_frmqn 作者: okdshin 项目源码 文件源码
def read(self, h):
        #M_key = F.swapaxes(F.stack(self.key_buff, axis=0), axis1=0, axis2=1) # (B, M, m)
        M_key = F.stack(self.key_buff, axis=1) # (B, M, m)

        self.p = F.softmax(F.reshape(F.batch_matmul(M_key, h, transa=False, transb=False), (h.shape[0], M_key.shape[1]))) # (B, M)
        #p = F.reshape(p, (h.shape[0], 1, M_key.shape[1])) # (B, 1, M)
        #print("p", p.shape)
        #M_val = F.swapaxes(F.stack(self.val_buff, axis=0), axis1=0, axis2=1) # (B, M, m)
        M_val = F.stack(self.val_buff, axis=1) # (B, M, m)
        #print("M_val", M_val.shape)
        o = F.batch_matmul(self.p, M_val, transa=True, transb=False) # (B, 1, m)
        o = F.reshape(o, (o.shape[0], o.shape[2])) # (B, m)
        #print("o", o.shape)
        return o, self.p
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号