matching_q.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:KerasRL 作者: aejax 项目源码 文件源码
def batch_sim5(w, M, eps=1e-6):
    """
    w: matrix with shape (batch, memory_elem)
    M: tensor with shape (batch, memory_size, memory_elem)
    eps: numerical stability parameter
    """
    M = M[0] # (memory_size, memory_elem)
    def batch_cos_sim(m, w, eps=eps):
        """
        Takes two vectors and calculates the scalar cosine similarity.

        m: vector with shape (memory_elem,)
        w: vector with shape (batch, memory_elem)
        returns: scalar
        """
        sim = T.dot(m,w.T) / T.sqrt((m*m).sum() * (w*w).sum(1) + eps)
        return sim #(batch,)

    sim, _ = theano.map(fn=batch_cos_sim, sequences=[M], non_sequences=[w])
    sim = sim.dimshuffle(1,0) # (batch, memory_size)
    return sim
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号