layers.py source code

python

Project: fg-gating  Author: kimiyoung
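import theano.tensor as T  # the method below refers to Theano's tensor module as T

def get_output_for(self, inputs, **kwargs):
    # inputs: passage states (batch, time_p, units), question states (batch, time_q, units),
    # question mask (batch, time_q), and a score term broadcastable to (batch, time_p, time_q)
    p_gru, q_gru, q_mask, feature = tuple(inputs)
    time_p = p_gru.shape[1]
    time_q = q_gru.shape[1]
    p_gru_re = p_gru.dimshuffle(0, 1, 'x', 2) # (batch, time_p, 1, units)
    q_gru_re = q_gru.dimshuffle(0, 'x', 1, 2) # (batch, 1, time_q, units)
    # Elementwise product of every passage/question state pair, squashed by tanh
    gru_merge = T.tanh(p_gru_re * q_gru_re).reshape((-1, time_q, self.units)) # (batch * time_p, time_q, units)

    # Attention logits: pairwise term + question-only term + external feature term
    att = T.dot(gru_merge, self.v1).reshape((-1, time_p, time_q)) # (batch, time_p, time_q)
    att_q = T.dot(q_gru, self.v2).squeeze() # (batch, time_q)
    att = att + att_q.dimshuffle(0, 'x', 1) + feature # (batch, time_p, time_q)
    # Softmax over question positions, then zero out masked positions and renormalize
    att = T.nnet.softmax(att.reshape((-1, time_q))) # (batch * time_p, time_q)
    att = att.reshape((-1, time_p, time_q)) * q_mask.dimshuffle(0, 'x', 1) # (batch, time_p, time_q)
    att = att / (att.sum(axis=2, keepdims=True) + 1e-8) # (batch, time_p, time_q)
    att = att.reshape((-1, time_q))

    # Attention-weighted sum of the merged states for each passage position
    output = T.batched_dot(att, gru_merge) # (batch * time_p, units)
    output = output.reshape((-1, time_p, self.units))
    return output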
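
To check the shapes and the mask-then-renormalize step without building a Theano graph, the same computation can be sketched in NumPy. This is an illustrative mirror, not part of the fg-gating project: the function name `pairwise_attention` is hypothetical, and `v1` and `v2` are assumed to be plain vectors of shape `(units,)`, which the original snippet does not confirm.

```python
import numpy as np

def pairwise_attention(p, q, q_mask, feature, v1, v2):
    """NumPy sketch of the attention computed by get_output_for.

    Assumed shapes (not confirmed by the original source):
      p: (batch, time_p, units), q: (batch, time_q, units),
      q_mask: (batch, time_q), feature: (batch, time_p, time_q),
      v1, v2: (units,)
    """
    # Elementwise product of every passage/question state pair
    merged = np.tanh(p[:, :, None, :] * q[:, None, :, :])  # (batch, time_p, time_q, units)

    # Logits: pairwise term + question-only term + external feature term
    scores = merged @ v1 + (q @ v2)[:, None, :] + feature  # (batch, time_p, time_q)

    # Softmax over the question axis (max subtracted for numerical stability)
    scores = scores - scores.max(axis=2, keepdims=True)
    att = np.exp(scores)
    att = att / att.sum(axis=2, keepdims=True)

    # Zero out masked question positions, then renormalize
    att = att * q_mask[:, None, :]
    att = att / (att.sum(axis=2, keepdims=True) + 1e-8)

    # Attention-weighted sum of merged states per passage position
    out = np.einsum('bpq,bpqu->bpu', att, merged)           # (batch, time_p, units)
    return out, att
```

Calling it with random inputs and a mask that marks trailing question positions as padding should give attention rows that are zero at the masked positions and sum to approximately 1 over the rest.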