matching_q.py source code

python

Project: KerasRL  Author: aejax
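import theano
import theano.tensor as T


def batch_sim(w, M, eps=1e-6):
    """
    Computes the cosine similarity between each key w[b] and every
    memory row M[b, i].

    w: matrix with shape (batch, memory_elem)
    M: tensor with shape (batch, memory_size, memory_elem)
    eps: numerical stability parameter
    returns: matrix with shape (batch, memory_size)
    """
    # Move memory_size to the front so theano.map iterates over memory slots.
    M = M.dimshuffle(1, 0, 2)  # (memory_size, batch, memory_elem)

    def cos_sim(u, v, eps=eps):
        """
        Takes two vectors and calculates the scalar cosine similarity.

        u: vector with shape (memory_elem,)
        v: vector with shape (memory_elem,)
        returns: scalar
        """
        # eps is added inside the square root to avoid division by zero.
        sim = T.dot(u, v) / T.sqrt((u * u).sum() * (v * v).sum() + eps)
        return sim

    def batch_cos_sim(m_i, w):
        """
        Takes two matrices and calculates the cosine similarity
        between their corresponding rows.

        m_i: matrix with shape (batch, memory_elem)
        w: matrix with shape (batch, memory_elem)
        returns: vector with shape (batch,)
        """
        sim, _ = theano.map(fn=cos_sim, sequences=[w, m_i])
        return sim

    # For each memory slot m_i, compare it against w across the whole batch.
    sim, _ = theano.map(fn=batch_cos_sim, sequences=[M], non_sequences=[w])
    # theano.map stacks the results as (memory_size, batch); transpose back.
    sim = sim.dimshuffle(1, 0)  # (batch, memory_size)
    return sim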
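The two nested `theano.map` calls above compute, for every batch index b and memory slot i, the value dot(w[b], M[b, i]) / sqrt(|w[b]|^2 * |M[b, i]|^2 + eps). As a point of comparison, the same result can be written as one vectorized expression. This is a sketch that assumes NumPy as a swapped-in library (it is not part of the KerasRL source), and the name `batch_sim_np` is hypothetical. It keeps the original's placement of eps inside the square root.

```python
import numpy as np


def batch_sim_np(w, M, eps=1e-6):
    """Hypothetical NumPy equivalent of batch_sim (sketch, not the original).

    w: array with shape (batch, memory_elem)
    M: array with shape (batch, memory_size, memory_elem)
    returns: array with shape (batch, memory_size)
    """
    # Dot product of each key w[b] with every memory row M[b, i].
    dots = np.einsum("be,bne->bn", w, M)      # (batch, memory_size)
    w_sq = (w * w).sum(axis=-1)               # (batch,)
    m_sq = (M * M).sum(axis=-1)               # (batch, memory_size)
    # Same stabilizer as cos_sim: eps goes inside the square root.
    return dots / np.sqrt(w_sq[:, None] * m_sq + eps)


# Usage: one key compared against three memory rows.
w = np.array([[1.0, 0.0]])
M = np.array([[[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]])
print(batch_sim_np(w, M))  # approximately [[1, 0, -1]]
```

Replacing the per-element loops with a single contraction avoids iterating over memory slots and batch entries one at a time, while the broadcast `w_sq[:, None] * m_sq` reproduces the per-pair norm product used by `cos_sim`.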