def dot_2d(k, M, b=None, g=None):
    """Return the (optionally weighted) dot product of each key with each memory row.

    For every sample ``i`` and memory slot ``j`` this computes
    ``sum_w k[i, w] * M[i, j, w] * b[i, w] * g[w]``. The ``b`` and ``g``
    factors are included only when they are given. No normalization is
    applied, so the result is a raw dot product, not a cosine similarity.

    Only broadcasting multiplication and the ``.sum(axis=...)`` method are
    used. Any array type supporting both works, e.g. NumPy arrays, and
    presumably the symbolic tensors the previous ``T.sum`` call targeted
    (confirm against callers).

    Args:
        k: Keys, shape ``(nb_samples, memory_width)``.
        M: Memories, shape ``(nb_samples, memory_dim, memory_width)``.
        b: Optional per-sample weights over the width axis, indexed as a 2-D
            array ``(nb_samples, memory_width)``.
        g: Optional width-axis weights shared by all samples, indexed as a
            1-D array ``(memory_width,)``.

    Returns:
        Similarities of shape ``(nb_samples, memory_dim)``.
    """
    # Insert a memory_dim axis so each key broadcasts against every memory row:
    # (nb_samples, 1, memory_width) * (nb_samples, memory_dim, memory_width).
    value = k[:, None, :] * M
    if b is not None:
        # Rebind instead of `*=`: with NumPy, an in-place multiply of an
        # integer product by float weights raises a casting error.
        value = value * b[:, None, :]
    if g is not None:
        value = value * g[None, None, :]
    # Reduce over memory_width -> (nb_samples, memory_dim).
    return value.sum(axis=2)
评论列表
文章目录