def __call__(self, X, w_temp, m_temp):
    """Compute this head's addressing weights for one step.

    Runs the addressing pipeline named in the original inline comments:
    content-based addressing, history interpolation, convolutional shift,
    sharpening, and finally row-wise renormalisation.

    Shapes are as stated in the original inline comments.

    Args:
        X: input features, (nb_samples, input_dim).
        w_temp: weights from the previous step (the "history" term that is
            interpolated below), (nb_samples, memory_dim).
        m_temp: memory tensor, (nb_samples, memory_dim, memory_width).

    Returns:
        w_t: (nb_samples, memory_dim). Each row is divided by its own sum,
        so every row sums to 1 provided that sum is non-zero.
    """
    # NOTE(review): `dot(X, W, b)` presumably computes X . W + b -- confirm
    # against its definition elsewhere in the module.
    key = dot(X, self.W_key, self.b_key)  # (nb_samples, memory_width)
    shift = self.softmax(
        dot(X, self.W_shift, self.b_shift))  # (nb_samples, shift_width)

    # `[:, None]` appends a length-1 axis so these per-sample scalars
    # broadcast against the (nb_samples, memory_dim) weight tensors below.
    beta = self.softplus(
        dot(X, self.W_beta, self.b_beta))[:, None]  # (nb_samples, 1)
    # The +1 offset presumably keeps the sharpening exponent >= 1
    # (assuming self.softplus is the usual non-negative softplus).
    gamma = self.softplus(
        dot(X, self.W_gama, self.b_gama)) + 1.  # (nb_samples,)
    gamma = gamma[:, None]  # (nb_samples, 1)
    g = self.sigmoid(dot(X, self.W_g, self.b_g))[:, None]  # (nb_samples, 1)

    # Content-based addressing: key-vs-memory similarity (cosine, per the
    # helper's name), scaled by beta and normalised with softmax.
    w_c = self.softmax(
        beta * cosine_sim2d(key, m_temp))  # (nb_samples, memory_dim)
    # History interpolation: gate g blends content weights with w_temp.
    w_g = g * w_c + (1 - g) * w_temp  # (nb_samples, memory_dim)
    # Convolutional shift.
    w_s = shift_convolve2d(
        w_g, shift, self.shift_conv)  # (nb_samples, memory_dim)
    # Sharpening.
    w_p = w_s ** gamma  # (nb_samples, memory_dim)
    # Renormalise each row so it sums to 1.
    # NOTE(review): no epsilon in the denominator -- if every entry of a row
    # underflows to 0 this division would yield NaN.
    w_t = w_p / T.sum(w_p, axis=1)[:, None]  # (nb_samples, memory_dim)
    return w_t
评论列表
文章目录