def __call__(self, u_t, a, b, scope=None):
    """
    Mix the element-wise product ``a * u_t`` along the memory axis using
    weights built from ``self.L``, ``self.sL`` and ``log(b)``.

    The weights are ``L * exp(L @ (log(b) * sL))``, where the first memory
    slot of ``log(b)`` is replaced by zeros before use.

    :param u_t: [N, M, d]
    :param a: [N, M, d]
    :param b: [N, M, d]
    :param scope: optional name for the ``tf.name_scope``; when falsy the
        class name is used instead
    :return: [N, M, d]
    """
    # mem_size is read alongside the other sizes but its value is not needed below.
    n_batch, n_mem, n_hidden = self.batch_size, self.mem_size, self.hidden_size
    base_L, base_sL = self.L, self.sL
    # NOTE(review): tf.concat(dim, values) and tf.batch_matmul look like the
    # pre-1.0 TensorFlow API - confirm the TF version before upgrading.
    with tf.name_scope(scope or self.__class__.__name__):
        # Add two leading axes and repeat over batch and hidden dimensions.
        # The shape comments below assume self.L / self.sL are [M, M] - TODO confirm.
        tile_counts = [n_batch, n_hidden, 1, 1]
        L_lifted = tf.expand_dims(tf.expand_dims(base_L, 0), 0)
        L_tiled = tf.tile(L_lifted, tile_counts)  # [N, d, M, M]
        sL_lifted = tf.expand_dims(tf.expand_dims(base_sL, 0), 0)
        sL_tiled = tf.tile(sL_lifted, tile_counts)  # [N, d, M, M]

        # Small epsilon keeps the log finite when b contains zeros.
        log_b = tf.log(b + 1e-9)  # [N, M, d]

        # Swap the first memory slot of log(b) for zeros, keep the rest.
        leading_zeros = tf.zeros([n_batch, 1, n_hidden])
        remaining = tf.slice(log_b, [0, 1, 0], [-1, -1, -1])
        log_b = tf.concat(1, [leading_zeros, remaining])  # [N, M, d]

        # Move the memory axis last-but-one and add a trailing unit axis.
        log_b = tf.transpose(log_b, [0, 2, 1])
        log_b = tf.expand_dims(log_b, -1)  # [N, d, M, 1]

        # log_b broadcasts against sL_tiled across the last axis.
        exponent = tf.batch_matmul(L_tiled, log_b * sL_tiled)
        weights = L_tiled * tf.exp(exponent)  # [N, d, M, M]

        gated = a * u_t  # [N, M, d]
        gated = tf.transpose(gated, [0, 2, 1])
        gated = tf.expand_dims(gated, -1)  # [N, d, M, 1]

        mixed = tf.batch_matmul(weights, gated)  # [N, d, M, 1]
        mixed = tf.squeeze(mixed, [3])
        result = tf.transpose(mixed, [0, 2, 1])  # [N, M, d]
    return result
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)