def __call__(self, inputs, state, scope=None):
    """Attend over the question vectors packed into `inputs`, then run the wrapped cell.

    With d = self._input_size and JQ = self._q_len, every row of `inputs` is
    sliced as three consecutive segments:
    [ x (d) | question mask (JQ) | flattened question vectors (JQ * d) ].

    :param inputs: [N, d + JQ + JQ * d]
    :param state: a two-element pair that is unpacked as (c_prev, h_prev).
        Only h_prev is read here; the pair itself is forwarded untouched
        to self._cell. (h_prev is presumably [N, d] -- confirm against
        the wrapped cell.)
    :param scope: variable scope to build under; falls back to the class name
        when falsy.
    :return: the result of self._cell(z, state), where z is x concatenated
        with the attention-weighted sum of the question vectors ([N, 2d]).
    """
    with tf.variable_scope(scope or self.__class__.__name__):
        # The first element of the state pair is not needed for the attention.
        _c_prev, h_prev = state
        dim = self._input_size
        num_q = self._q_len

        # Unpack the three segments of the flat input row.
        x = tf.slice(inputs, [0, 0], [-1, dim])
        q_mask = tf.slice(inputs, [0, dim], [-1, num_q])  # [N, JQ]
        flat_questions = tf.slice(inputs, [0, dim + num_q], [-1, -1])
        questions = tf.reshape(flat_questions, [-1, num_q, dim])  # [N, JQ, d]

        # Repeat x and h_prev along a new question axis so each question
        # vector is paired with them.
        x_per_q = tf.tile(tf.expand_dims(x, 1), [1, num_q, 1])  # [N, JQ, d]
        h_per_q = tf.tile(tf.expand_dims(h_prev, 1), [1, num_q, 1])  # [N, JQ, d]

        # Score each question position. `linear` and `exp_mask` are project
        # helpers; exp_mask presumably suppresses masked-out positions
        # before the softmax -- TODO confirm.
        projected = linear([questions, x_per_q, h_per_q], dim, True, scope='f')
        features = tf.tanh(projected)  # [N, JQ, d]
        logits = linear(features, 1, True, squeeze=True, scope='a')
        weights = tf.nn.softmax(exp_mask(logits, q_mask))  # [N, JQ]

        # Attention-weighted sum of the question vectors.
        attended = tf.reduce_sum(questions * tf.expand_dims(weights, -1), 1)

        # NOTE: axis-first argument order, i.e. the pre-1.0 TensorFlow
        # tf.concat(concat_dim, values) signature.
        cell_input = tf.concat(1, [x, attended])  # [N, 2d]
        return self._cell(cell_input, state)
评论列表
文章目录