def __call__(self, x, h_prev, scope=None):
    """Run one GRU step through the fused `gru_block_cell` op.

    Args:
        x: Input tensor. Its static shape is constrained to rank 2 via
            `with_rank(2)`; the second dimension is used as the input size
            when shaping the weight variables.
        h_prev: Previous-state tensor. Its static shape is constrained to
            rank 2 via `with_rank(2)`; the second dimension must equal
            `self._cell_size`.
        scope: Optional variable scope to create the variables under.
            Defaults to the class name.

    Returns:
        A pair `(new_h, new_h)`: the fourth output of the `gru_block_cell`
        op, returned twice.

    Raises:
        ValueError: If the second dimension of `x` is None, if the second
            dimension of `h_prev` is None, or if the second dimension of
            `h_prev` differs from `self._cell_size`.
    """
    with vs.variable_scope(scope or type(self).__name__):
        input_size = x.get_shape().with_rank(2)[1]
        # The weight shapes below need a known input size.
        if input_size is None:
            raise ValueError("Expecting input_size to be set.")

        cell_size = h_prev.get_shape().with_rank(2)[1]
        # Reject a missing size before comparing it: a None value would
        # otherwise trip the mismatch check and this more specific message
        # could never be reported.
        if cell_size is None:
            raise ValueError("cell_size from `h_prev` should not be None.")
        # Check cell_size == state_size from h_prev.
        if cell_size != self._cell_size:
            raise ValueError("Shape of h_prev[1] incorrect: cell_size %i vs %s" %
                             (self._cell_size, cell_size))

        # Both weight matrices take `input_size + cell_size` rows.
        concat_size = input_size + self._cell_size

        # Variables are created in the same order as before:
        # w_ru, b_ru, w_c, b_c. The weights pass no explicit initializer;
        # b_ru is initialized to the constant 1.0 and b_c to 0.0.
        w_ru = vs.get_variable("w_ru", [concat_size, self._cell_size * 2])
        b_ru = vs.get_variable(
            "b_ru", [self._cell_size * 2],
            initializer=init_ops.constant_initializer(1.0))
        w_c = vs.get_variable("w_c", [concat_size, self._cell_size])
        b_c = vs.get_variable(
            "b_c", [self._cell_size],
            initializer=init_ops.constant_initializer(0.0))

        # Only the last of the op's four outputs is used here.
        _, _, _, new_h = _gru_ops_so.gru_block_cell(
            x=x, h_prev=h_prev, w_ru=w_ru, w_c=w_c, b_ru=b_ru, b_c=b_c)
        return new_h, new_h
# (Web-page navigation residue, not part of the code: "comment list" / "table of contents".)