def __call__(self, inputs, state, scope=None):
    """Run one step of the gated cell.

    `inputs` is split into two equal halves, `h` and `u`. A sigmoid gate
    `a` is computed from a learned projection (see `self._logit_func`) of
    `h`, `u` and `state`. The new state is then the gated mix
    `a * state + (1 - a) * h`.

    Args:
        inputs: Tensor split into two equal parts along dimension 1
            (`h`, `u`). Presumably shaped (batch, 2 * state_size) so that
            `h` matches `state` elementwise; TODO confirm against callers.
        state: Previous state tensor, combined elementwise with `h` and `u`.
        scope: Optional variable-scope name; defaults to "SHCell".

    Returns:
        A tuple `(outputs, new_state)`. `outputs` is the *incoming* `state`,
        not `new_state`.

    Raises:
        ValueError: if `self._logit_func` is not one of 'mul_linear',
            'linear', 'tri_linear' or 'double'.
    """
    with tf.variable_scope(scope or "SHCell"):
        # Gate width: a single scalar gate per example, or one gate per
        # state unit.
        a_size = 1 if self._scalar else self._state_size
        # NOTE(review): this is the legacy (pre-1.0) TensorFlow argument
        # order, tf.split(split_dim, num_split, value). On TF >= 1.0 the
        # equivalent call is tf.split(inputs, 2, 1).
        h, u = tf.split(1, 2, inputs)

        # Choose the features the gate logits are computed from.
        if self._logit_func == 'mul_linear':
            args = [h * u, state * u]
        elif self._logit_func in ('linear', 'double'):
            args = [h, u, state]
        elif self._logit_func == 'tri_linear':
            args = [h, u, state, h * u, state * u]
        else:
            raise ValueError(
                "Unknown logit_func %r; expected one of 'mul_linear', "
                "'linear', 'tri_linear', 'double'." % (self._logit_func,))

        if self._logit_func == 'double':
            # Two-layer projection: linear -> tanh -> linear.
            # NOTE(review): the outer projection is always state_size wide,
            # so `self._scalar` only narrows the hidden layer here. Both
            # `linear` calls also run in the same variable scope; confirm
            # `linear` does not create same-named variables on each call.
            logits = linear(tf.tanh(linear(args, a_size, True)),
                            self._state_size, True)
        else:
            logits = linear(args, a_size, True)
        a = tf.nn.sigmoid(logits)

        # Gate between keeping the previous state (a -> 1) and taking the
        # `h` half of the input (a -> 0).
        new_state = a * state + (1 - a) * h
        # NOTE(review): the emitted output is the state from *before* this
        # step's update; verify this one-step lag is intended rather than
        # `outputs = new_state`.
        outputs = state
        return outputs, new_state
评论列表
文章目录