def __call__(self, inputs, state, scope=None):
    with tf.variable_scope(scope or type(self).__name__):
        c_prev, h_prev, update_prob_prev, cum_update_prob_prev = state

        # Parameters of gates are concatenated into one multiply for efficiency.
        concat = rnn_ops.linear([inputs, h_prev], 4 * self._num_units, True)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1)

        if self._layer_norm:
            i = rnn_ops.layer_norm(i, name="i")
            j = rnn_ops.layer_norm(j, name="j")
            f = rnn_ops.layer_norm(f, name="f")
            o = rnn_ops.layer_norm(o, name="o")
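
        # Candidate cell and hidden state, computed as if the state were updated at this step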
        new_c_tilde = (c_prev * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * self._activation(j))
        new_h_tilde = self._activation(new_c_tilde) * tf.sigmoid(o)

        # Compute the update probability to be used at the next time step
        with tf.variable_scope('state_update_prob'):
            new_update_prob_tilde = rnn_ops.linear(new_c_tilde, 1, True, bias_start=self._update_bias)
            new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde)
        # Compute value for the update gate: accumulate the update probability,
        # clamped so the cumulative value never exceeds 1
        cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev)
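        # Binarize: the forward pass rounds to {0, 1}; _binary_round uses a
        # straight-through estimator so gradients still flow (see sketch below)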
        update_gate = _binary_round(cum_update_prob)

        # Apply update gate
        new_c = update_gate * new_c_tilde + (1. - update_gate) * c_prev
        new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev
        new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev
        new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob

        new_state = SkipLSTMStateTuple(new_c, new_h, new_update_prob, new_cum_update_prob)
        new_output = SkipLSTMOutputTuple(new_h, update_gate)

        return new_output, new_state
Source: skip_rnn_cells.py
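
The method references a few helpers defined elsewhere in skip_rnn_cells.py: rnn_ops.linear and rnn_ops.layer_norm, the SkipLSTMStateTuple / SkipLSTMOutputTuple containers, and _binary_round. The sketch below shows plausible minimal definitions of the last three, assuming plain namedtuples (field names inferred from how the state is unpacked and repacked above) and the standard straight-through estimator trick for the binarization; the exact bodies in the repository may differ.

import collections
import tensorflow as tf

# Assumed containers: field names inferred from the unpack/pack sites in __call__.
SkipLSTMStateTuple = collections.namedtuple(
    "SkipLSTMStateTuple", ("c", "h", "update_prob", "cum_update_prob"))
SkipLSTMOutputTuple = collections.namedtuple(
    "SkipLSTMOutputTuple", ("h", "state_gate"))


def _binary_round(x):
    """Round values in [0, 1] to {0, 1} with a straight-through gradient.

    Forward pass: tf.round. Backward pass: the rounding is treated as the
    identity, so gradients flow back into the update-probability branch.
    """
    g = tf.get_default_graph()
    with g.gradient_override_map({"Round": "Identity"}):
        return tf.round(x)

This binarization is what makes the hard 0/1 gate trainable end to end: when the gate is 0 the previous state is copied unchanged (so the step's state update can in principle be skipped entirely at inference), while during training the straight-through estimator supplies a usable gradient for the update probability.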