def batch_sim6(w, M, eps=1e-6):
    """Cosine similarity between each row of ``w`` and each slot of one memory.

    w: matrix with shape (batch, memory_elem)
    M: tensor with shape (batch, memory_size, memory_elem); only ``M[0]`` is
       used ("only one true memory"), so every row of ``w`` is compared
       against the same memory.
    eps: numerical stability parameter added to the product of norms so a
       zero-norm row does not cause a division by zero.

    return: symbolic matrix with shape (batch, memory_size) whose entry
       [b, i] is dot(w[b], M[0][i]) / (|w[b]| * |M[0][i]| + eps).
    """
    # NOTE(review): assumes ``T`` is ``theano.tensor`` imported at file level.
    memory = M[0]  # only one true memory: (memory_size, memory_elem)

    def row_norm(A):
        """Return the Euclidean norm of each row of A.

        A: matrix with shape (N, K)
        return: vector with shape (N,)
        """
        # Single vectorized reduction over the last axis; equivalent to
        # mapping sqrt(sum(a * a)) over the rows, without a scan loop.
        return T.sqrt((A * A).sum(axis=1))

    # Outer product of norms pairs each batch row with each memory slot.
    norm_prod = T.outer(row_norm(w), row_norm(memory))  # (batch, memory_size)
    return T.dot(w, memory.T) / (norm_prod + eps)  # (batch, memory_size)
评论列表
文章目录