matching_q.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:KerasRL 作者: aejax 项目源码 文件源码
def batch_sim6(w, M, eps=1e-6):
    """
    w: matrix with shape (batch, memory_elem)
    M: tensor with shape (batch, memory_size, memory_elem)
    eps: numerical stability parameter
    """
    M = M[0] #only one true memory
    #M = M.dimshuffle(1,0) # (memory_elem, memory_size)
    def norm(A):
        """
        Calculate the column norm of matrix A
        A: matrix with shape (N, M)
        return: vector with shape (N,)
        """
        n, _ = theano.map(fn=lambda a: T.sqrt((a*a).sum()), sequences=[A])
        return n

    norm = T.outer(norm(w), norm(M)) #(batch, memory_size)
    batch_sim = T.dot(w, M.T) / (norm + eps) #(batch, memory_size)
    return batch_sim
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号