def __call__(self, inputs, state, scope=None):
    """Run one time step of the structurally constrained recurrent cell.

    The cell keeps two recurrent states:
      * ``s`` -- the slow "context" state, updated with a fixed decay
        ``self._alpha``:  s_t = (1 - alpha) * B x_t + alpha * s_{t-1}
      * ``h`` -- the fast hidden state:
        h_t = tanh(W [x_t, s_t, h_{t-1}] + b)

    Args:
        inputs: input tensor for this step. Presumably 2-D
            (batch, input_size) -- TODO confirm against callers.
        state: previous state. If ``self._state_is_tuple`` it is an
            ``(s, h)`` pair; otherwise ``s`` and ``h`` are concatenated
            along axis 1 and are split back into two equal halves here.
        scope: optional variable-scope name; defaults to the class name.

    Returns:
        ``(new_h, new_state)``. ``new_h`` is the step output.
        ``new_state`` is an ``LSTMStateTuple(new_s, new_h)`` when
        ``self._state_is_tuple``; otherwise it is the axis-1
        concatenation of ``new_s`` and ``new_h``.

    Note:
        ``"SlowLinear"`` now projects only ``inputs`` (no bias). Its
        variable shapes therefore differ from the earlier version, which
        projected ``[inputs, s]`` with a bias, so checkpoints saved by that
        version will not restore into this scope.
    """
    with vs.variable_scope(scope or type(self).__name__):  # "SCRNNCell"
        if self._state_is_tuple:
            s, h = state
        else:
            # Old (pre-1.0) TF argument order: split(axis, num_splits, value).
            s, h = array_ops.split(1, 2, state)

        # Slow context update. Only the input goes through the learned
        # projection B; the previous context state is carried over with the
        # fixed weight alpha. Routing alpha * s through a learned matrix (as
        # before) lets the weights cancel alpha and removes the slow-state
        # constraint entirely.
        projected_inputs = tf.nn.rnn_cell._linear(
            inputs, self._num_units, False, scope="SlowLinear")
        new_s = (1 - self._alpha) * projected_inputs + self._alpha * s

        # Fast hidden update conditioned on the input, the *new* context
        # state and the previous hidden state.
        new_h = tanh(tf.nn.rnn_cell._linear(
            [inputs, new_s, h], self._num_units, True, scope="FastLinear"))

        if self._state_is_tuple:
            new_state = tf.nn.rnn_cell.LSTMStateTuple(new_s, new_h)
        else:
            # Old (pre-1.0) TF argument order: concat(axis, values).
            new_state = array_ops.concat(1, [new_s, new_h])
    return new_h, new_state
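# Scraped web-page residue (not code): "Comment list"
# Scraped web-page residue (not code): "Table of contents"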