def __call__(self, inputs, state, scope=None):
    """Run one step of the RSM cell.

    Column layout along axis 1, as established by the slices below:

      inputs = [a (gate_size cols) | x | u | v_t]
               (the columns after ``a`` are split into three equal parts)
      state  = [o (gate_size cols) | h | v]
               (the columns after ``o`` are split into two equal parts)

    Args:
      inputs: 2-D tensor laid out as described above.
      state: 2-D tensor from the previous step, laid out as described above.
      scope: Optional variable-scope name; defaults to the class name.

    Returns:
      A pair ``(outputs, new_state)``:
        outputs   -- axis-1 concatenation ``[a, r, x, new_h, new_v, g]``.
        new_state -- axis-1 concatenation ``[new_o, new_h, new_v]``, i.e. the
                     same layout this method reads from ``state``.
    """
    gate_size = self._gate_size
    with tf.variable_scope(scope or type(self).__name__):  # "RSMCell"
        with tf.name_scope("Split"):  # Slice inputs and state into parts.
            # Gate ``a`` is the first gate_size input columns; the rest is
            # split evenly into x, u and v_t.
            a = tf.slice(inputs, [0, 0], [-1, gate_size])
            x, u, v_t = tf.split(
                1, 3, tf.slice(inputs, [0, gate_size], [-1, -1]))
            # ``o`` occupies the first gate_size state columns: h and v are
            # read starting at column gate_size, and new_o (computed from the
            # gate_size-wide ``a``) is written back into that slot below.
            # Reading only one column here would drop state when
            # gate_size > 1.
            o = tf.slice(state, [0, 0], [-1, gate_size])
            h, v = tf.split(1, 2, tf.slice(state, [0, gate_size], [-1, -1]))
        with tf.variable_scope("Main"):
            # Presumably one value per row (assumes ``linear``'s second
            # argument is the output size -- TODO confirm), computed from the
            # elementwise product x * u.
            r_raw = linear([x * u], 1, True, scope='r_raw',
                           var_on_cpu=self._var_on_cpu,
                           initializer=self._initializer)
            # NOTE(review): op is named 'a' although it computes ``r``; kept
            # as-is since graph lookups by name may depend on it.
            r = tf.sigmoid(r_raw, name='a')
            # ``a`` interpolates between the new value and the previous one.
            new_o = a * r + (1 - a) * o
            new_v = a * v_t + (1 - a) * v
            g = r * v_t
            new_h = a * g + (1 - a) * h
        with tf.name_scope("Concat"):
            new_state = tf.concat(1, [new_o, new_h, new_v])
            outputs = tf.concat(1, [a, r, x, new_h, new_v, g])
    return outputs, new_state
评论列表
文章目录