def __call__(self, u_t, a, b, scope=None):
    """Return ``batch_matmul(L * exp(batch_matmul(L, log(b) * sL)), a * u_t)``.

    ``L`` and ``sL`` are read from ``self`` and tiled across the batch axis
    before use (presumably [M, M] each, given the [N, M, M] result noted
    below -- TODO confirm against where they are defined).

    :param u_t: [N, M, d]
    :param a: [N, M, 1]
    :param b: [N, M, 1]
    :param scope: optional name scope; defaults to the class name
    :return: [N, M, d]
    """
    # Only the batch size is used below; the other two sizes are the M and d
    # that appear in the shape notes.
    batch_size, _, _ = self.batch_size, self.mem_size, self.hidden_size
    base, scaled = self.L, self.sL
    with tf.name_scope(scope or self.__class__.__name__):
        # Replicate both matrices along a new leading batch axis.
        tiled_base = tf.tile(tf.expand_dims(base, 0), [batch_size, 1, 1])
        tiled_scaled = tf.tile(tf.expand_dims(scaled, 0), [batch_size, 1, 1])

        # The small epsilon keeps the log finite when b is exactly zero.
        log_b = tf.log(b + 1e-9)
        # Swap the first entry along axis 1 for zeros and keep the rest.
        head = tf.zeros([batch_size, 1, 1])
        tail = tf.slice(log_b, [0, 1, 0], [-1, -1, -1])
        log_b = tf.concat(1, [head, tail])

        exponent = tf.batch_matmul(tiled_base, log_b * tiled_scaled)
        left = tiled_base * tf.exp(exponent)  # [N, M, M]
        right = a * u_t  # [N, M, d]
        return tf.batch_matmul(left, right)  # [N, M, d]
评论列表
文章目录