def call(self, inputs, state, scope=None):
    """Run one step of the convolutional GRU (GRU-RCN) recurrence.

    With ``*`` denoting the convolutions built by ``self._conv`` and ``.``
    element-wise multiplication, this computes::

        z  = sigmoid(W_z * inputs + U_z * state + b_z)
        r  = sigmoid(W_r * inputs + U_r * state + b_r)
        c  = tanh(W * inputs + U * (r . state) + b_c)
        h' = z . state + (1 - z) . c

    The three input-to-hidden convolutions (W_z, W_r, W) are computed as one
    convolution with ``3 * self._num_outputs`` output channels and then
    split. U_z and U_r are fused the same way. U is a separate convolution
    because it consumes ``r . state``.

    Args:
        inputs: Input feature map for this time step. It is convolved with
            ``self._ih_strides`` and ``self._ih_pandding``. Presumably these
            must yield the same spatial size as ``state`` so that the sums
            below are well defined (TODO confirm with the cell's config).
        state: Previous hidden state. It is convolved with stride
            ``[1, 1, 1, 1]`` and ``"SAME"`` padding.
        scope: Optional variable-scope name. Defaults to the class name.

    Returns:
        ``(new_h, new_h)``: the new hidden state, used both as the cell
        output and as the next state.
    """
    # All convolution kernels use the same small-stddev initializer.
    weight_init = init_ops.truncated_normal_initializer(stddev=0.01)

    def _bias(name):
        # tf.get_variable resolves the variable scope that is active at call
        # time, so the names (".../Gates/z_biases", etc.) are unchanged.
        return tf.get_variable(
            name=name,
            shape=[self._num_outputs],
            initializer=init_ops.ones_initializer(),
        )

    with vs.variable_scope(scope or type(self).__name__):  # "GruRcnCell"
        with vs.variable_scope("Gates"):  # Reset gate and update gate.
            # One input-side convolution yields the pre-activations for the
            # update gate, the reset gate and the candidate.
            w_zrw = self._conv(
                inputs, self._num_outputs * 3,
                self._ih_filter_h_length, self._ih_filter_w_length,
                self._ih_strides, self._ih_pandding, weight_init,
                scope="WzrwConv")
            u_zr = self._conv(
                state, self._num_outputs * 2,
                self._hh_filter_h_length, self._hh_filter_w_length,
                [1, 1, 1, 1], "SAME", weight_init,
                scope="UzrConv")
            # Split along axis 3. Presumably this is the channel axis (NHWC
            # layout); confirm against the caller.
            w_z, w_r, w = tf.split(value=w_zrw, num_or_size_splits=3, axis=3, name="w_split")
            u_z, u_r = tf.split(value=u_zr, num_or_size_splits=2, axis=3, name="u_split")
            # Gate biases start at 1.0.
            z_bias = _bias("z_biases")
            z_gate = math_ops.sigmoid(tf.nn.bias_add(w_z + u_z, z_bias))
            r_bias = _bias("r_biases")
            r_gate = math_ops.sigmoid(tf.nn.bias_add(w_r + u_r, r_bias))
        with vs.variable_scope("Candidate"):
            # The input-side term `w` was already produced by WzrwConv above.
            # Only the hidden-side term needs its own convolution, because it
            # consumes the reset-gated state.
            u = self._conv(
                r_gate * state, self._num_outputs,
                self._hh_filter_h_length, self._hh_filter_w_length,
                [1, 1, 1, 1], "SAME", weight_init,
                scope="UConv")
            # NOTE(review): the candidate bias also starts at 1.0, whereas
            # TensorFlow's stock GRUCell defaults its candidate bias to zeros.
            # Confirm this is intentional.
            c_bias = _bias("c_biases")
            c = math_ops.tanh(tf.nn.bias_add(w + u, c_bias))
        # z close to 1 keeps the previous state; z close to 0 takes the candidate.
        new_h = z_gate * state + (1 - z_gate) * c
    return new_h, new_h
评论列表
文章目录