layers.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:fg-gating 作者: kimiyoung 项目源码 文件源码
def get_output_for(self, inputs, attention_only=False, **kwargs):

        # inputs[0]: B x N x D
        # inputs[1]: B x Q x D
        # inputs[2]: B x N x Q / B x Q x N
        # self.mask: B x Q

        if self.transpose: M = inputs[2].dimshuffle((0,2,1))
        else: M = inputs[2]
        alphas = T.nnet.softmax(T.reshape(M, (M.shape[0]*M.shape[1],M.shape[2])))
        alphas_r = T.reshape(alphas, (M.shape[0],M.shape[1],M.shape[2]))* \
                self.mask[:,np.newaxis,:] # B x N x Q
        alphas_r = alphas_r/alphas_r.sum(axis=2)[:,:,np.newaxis] # B x N x Q
        q_rep = T.batched_dot(alphas_r, inputs[1]) # B x N x D

        return eval(self.gating_fn)(inputs[0],q_rep)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号