def _get_weight_vector(self, M, w_tm1, k, beta, g, s, gamma):
    # M = tf.Print(M, [M, w_tm1, k], message='get weights beg1: ')
    # M = tf.Print(M, [beta, g, s, gamma], message='get weights beg2: ')
    # Content addressing, see Section 3.3.1:
    num = beta * _cosine_distance(M, k)
    w_c = K.softmax(num)  # It turns out that equation (5) is just a softmax.
    # Location addressing, see Section 3.3.2:
    # Equation 7: interpolate between the content weights and the previous weights.
    w_g = (g * w_c) + (1 - g) * w_tm1
    # C_s is the circular convolution, written as a batch of shift matrices:
    # C_w = K.sum((self.C[None, :, :, :] * w_g[:, None, None, :]), axis=3)
    # Equation 8: self.C stacks one (n x n) circular-shift matrix per allowed shift.
    # Weighting that stack by the shift distribution s and summing over the shift
    # axis yields one shift matrix per batch entry; applying it to w_g is exactly
    # the circular convolution of equation (8).
    C_s = K.sum(K.repeat_elements(self.C[None, :, :, :], self.batch_size, axis=0) * s[:, :, None, None], axis=1)
    w_tilda = K.batch_dot(C_s, w_g)
    # Equation 9: sharpen with gamma, then renormalise so the weights sum to one.
    w_out = _renorm(w_tilda ** gamma)
    return w_out
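
# --- Illustrative sketch (standalone, not part of the layer class above) ------
# A minimal NumPy walk-through of the same addressing pipeline (equations 5-9),
# assuming `n` memory slots, a symmetric shift range, and a circular-shift tensor
# analogous to self.C. The helpers `circulant` and `address_memory` are
# hypothetical names used only for this illustration.
import numpy as np

def circulant(n, n_shifts):
    # Stack of circular-shift matrices, one per allowed shift: shape (n_shifts, n, n).
    # Assumes shifts run symmetrically, e.g. {+1, 0, -1} for n_shifts = 3.
    eye = np.eye(n)
    shifts = range(n_shifts // 2, -(n_shifts // 2) - 1, -1)
    return np.stack([np.roll(eye, shift, axis=1) for shift in shifts])

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def address_memory(M, w_tm1, k, beta, g, s, gamma, C):
    # Equations 5/6: content addressing via softmax over beta-scaled cosine similarities.
    sim = M @ k / (np.linalg.norm(M, axis=1) * np.linalg.norm(k) + 1e-8)
    w_c = softmax(beta * sim)
    # Equation 7: interpolation gate between content weights and previous weights.
    w_g = g * w_c + (1 - g) * w_tm1
    # Equation 8: circular convolution, expressed as a shift matrix applied to w_g.
    C_s = np.tensordot(s, C, axes=1)  # (n, n) shift matrix weighted by s
    w_tilda = C_s @ w_g
    # Equation 9: sharpening and renormalisation.
    w = w_tilda ** gamma
    return w / w.sum()

# Usage example: 5 memory slots of width 3, shifts {+1, 0, -1}.
rng = np.random.default_rng(0)
C = circulant(5, 3)
w = address_memory(M=rng.normal(size=(5, 3)), w_tm1=np.full(5, 0.2),
                   k=rng.normal(size=3), beta=2.0, g=0.9,
                   s=np.array([0.1, 0.8, 0.1]), gamma=2.0, C=C)
print(w, w.sum())  # weights over the 5 slots; sums to 1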