def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell on a mask-reduced view of the previous state.

    Each row of ``inputs`` carries the wrapped cell's features in its first
    ``self._input_size`` columns, followed by one row of a mask. The previous
    state is masked with ``exp_mask`` and collapsed by ``self._reduce_func``
    before being handed to ``self._cell`` together with the features.

    Shape notes below are carried over from the original annotations.

    :param inputs: [N*B, I + B]
    :param state: [N*B, d]
    :param scope: variable scope to use; the class name is used when falsy
    :return: result of ``self._cell``; annotated as [N*B, d]
    """
    scope_or_name = scope or self.__class__.__name__
    with tf.variable_scope(scope_or_name):
        state_dim = self.state_size
        input_width = self._input_size

        # Split the packed input into cell features and the trailing mask.
        features = tf.slice(inputs, [0, 0], [-1, input_width])  # [N*B, I]
        flat_mask = tf.slice(inputs, [0, input_width], [-1, -1])  # [N*B, B]
        mask_width = tf.shape(flat_mask)[1]

        # Previous state: [N*B, d] -> [N, B, d] -> [N, 1, B, d]
        grouped_state = tf.reshape(state, [-1, mask_width, state_dim])
        state_4d = tf.expand_dims(grouped_state, 1)

        # Mask: [N*B, B] -> [N, B, B] -> [N, B, B, 1] -> [N, B, B, d]
        square_mask = tf.reshape(flat_mask, [-1, mask_width, mask_width])
        mask_4d = tf.tile(tf.expand_dims(square_mask, -1), [1, 1, 1, state_dim])

        # NOTE(review): exp_mask is defined outside this block; it presumably
        # suppresses masked-out entries ahead of the reduction -- confirm.
        masked_state = exp_mask(state_4d, mask_4d)
        # The second argument (2) is presumably the axis to reduce over.
        reduced_state = self._reduce_func(masked_state, 2)  # [N, B, d]

        flat_state = tf.reshape(reduced_state, [-1, state_dim])  # [N*B, d]
        return self._cell(features, flat_state)
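# (web-page residue, not part of the code) Comment list
# (web-page residue, not part of the code) Article table of contents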