layers.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:fg-gating 作者: kimiyoung 项目源码 文件源码
def get_output_for(self, inputs, **kwargs):

        # inputs[0]: B x N x D, doc
        # inputs[1]: B x Q x D, query
        # self.aggregator: B x N x C
        # self.pointer: B x 1
        # self.mask: B x N

        q = inputs[1][T.arange(inputs[1].shape[0]),self.pointer,:] # B x D
        p = T.batched_dot(inputs[0],q) # B x N
        pm = T.nnet.softmax(p)*self.mask # B x N
        pm = pm/pm.sum(axis=1)[:,np.newaxis] # B x N

        return T.batched_dot(pm, self.aggregator)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号