def __call__(self, inputs, state, scope=None):
    """Run one step of the cell and return ``(new_state, new_state)``.

    The step builds two sigmoid gates (``r`` and ``g``) from the inputs and
    the previous state, passes the previous state through ``_eunn_loop``,
    scales that result element-wise by ``r``, adds the input projection and
    applies ``modrelu`` to obtain a candidate. The new state is the
    element-wise blend ``g * state + (1 - g) * candidate``.

    Args:
        inputs: Input tensor for this step. It is multiplied by a
            ``[inputs.get_shape()[-1], 3 * hidden_size]`` kernel and the
            product is split along axis 1, so a 2-D tensor is expected.
        state: Previous state. It is multiplied by
            ``[hidden_size, hidden_size]`` kernels, so its last dimension
            is expected to be ``hidden_size``.
        scope: Optional variable scope; ``"goru_cell"`` is used when falsy.

    Returns:
        A pair whose two elements are the same new-state tensor.
    """
    with vs.variable_scope(scope or "goru_cell"):
        hidden = self._hidden_size

        # Initializers: small uniform weights, constant 2.0 for the r-gate
        # bias, constant 0.01 for the modrelu bias.
        weight_init = init_ops.random_uniform_initializer(-0.01, 0.01)
        r_bias_init = init_ops.constant_initializer(2.)
        modrelu_bias_init = init_ops.constant_initializer(0.01)

        # One fused input projection, split into three equal slices in the
        # order: candidate, r gate, g gate.
        input_kernel = vs.get_variable(
            "U",
            [inputs.get_shape()[-1], hidden * 3],
            dtype=tf.float32,
            initializer=weight_init,
        )
        input_proj = math_ops.matmul(inputs, input_kernel)
        cand_in, r_in, g_in = array_ops.split(input_proj, 3, axis=1)

        # State projections feeding the two gates.
        r_kernel = vs.get_variable(
            "W_r",
            [hidden, hidden],
            dtype=tf.float32,
            initializer=weight_init,
        )
        g_kernel = vs.get_variable(
            "W_g",
            [hidden, hidden],
            dtype=tf.float32,
            initializer=weight_init,
        )
        r_state = math_ops.matmul(state, r_kernel)
        g_state = math_ops.matmul(state, g_kernel)

        r_bias = vs.get_variable(
            "bias_r",
            [hidden],
            dtype=tf.float32,
            initializer=r_bias_init,
        )
        # NOTE(review): no explicit initializer is passed for bias_g, so it
        # falls back to whatever get_variable / the enclosing scope defaults
        # to -- confirm this asymmetry with bias_r is intended.
        g_bias = vs.get_variable("bias_g", [hidden], dtype=tf.float32)
        modrelu_bias = vs.get_variable(
            "bias_c",
            [hidden],
            dtype=tf.float32,
            initializer=modrelu_bias_init,
        )

        r_logits = r_in + r_state + r_bias
        g_logits = g_in + g_state + g_bias
        r_gate = math_ops.sigmoid(r_logits)
        g_gate = math_ops.sigmoid(g_logits)

        # Presumably a unitary/orthogonal transform of the state -- verify
        # against _eunn_loop.
        transformed = _eunn_loop(
            state,
            self._capacity,
            self.diag_vec,
            self.off_vec,
            self.diag,
            self._fft,
        )
        gated = math_ops.multiply(r_gate, transformed)
        pre_activation = gated + cand_in
        # The meaning of the trailing False is defined by modrelu, which is
        # not visible here.
        candidate = modrelu(pre_activation, modrelu_bias, False)

        kept = math_ops.multiply(g_gate, state)
        fresh = math_ops.multiply(1 - g_gate, candidate)
        new_state = kept + fresh

    return new_state, new_state
评论列表
文章目录