def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
               sequence_length):
    """Run this LSTM on inputs, starting from the given state.

    Fetches the kernel, bias and (optionally) peephole variables through
    `vs.get_variable`, then runs the whole sequence through the
    `block_lstm` op in a single call.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output. It is
        also used as the dtype of every variable created here.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in `[0,
        time_len)` or None. When None, the full `time_len` is used as the
        maximum sequence length.

    Returns:
      A pair containing:
      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
                         output_size]`
      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
                    output_size]`
    """
    inputs_shape = inputs.get_shape().with_rank(3)
    time_len = inputs_shape[0].value
    if time_len is None:
        # Static time dimension is unknown; fall back to the dynamic shape.
        time_len = array_ops.shape(inputs)[0]
    input_size = inputs_shape[2].value

    # Single fused weight matrix over the concatenated [input, h_prev];
    # the second dimension holds the four gate blocks.
    w = vs.get_variable(
        "W_0", [input_size + self._num_units, self._num_units * 4],
        dtype=dtype)
    b = vs.get_variable(
        "B", [w.get_shape().with_rank(2)[1]],
        initializer=init_ops.constant_initializer(0.0),
        dtype=dtype)

    if self._use_peephole:
        wci = vs.get_variable("W_I_diag", [self._num_units], dtype=dtype)
        wco = vs.get_variable("W_O_diag", [self._num_units], dtype=dtype)
        wcf = vs.get_variable("W_F_diag", [self._num_units], dtype=dtype)
    else:
        # The op still takes peephole inputs when they are disabled, so
        # pass zeros of the right shape.
        wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype)

    # `seq_len_max` is passed as int64 in both branches. The dynamic
    # `array_ops.shape(inputs)[0]` fallback above yields an int32 tensor by
    # default, so it must be cast here just like the reduced
    # `sequence_length` is.
    if sequence_length is None:
        max_seq_len = math_ops.to_int64(time_len)
    else:
        max_seq_len = math_ops.to_int64(math_ops.reduce_max(sequence_length))

    # Only the cell-state and output sequences are returned; the remaining
    # outputs of the op are discarded.
    _, cs, _, _, _, _, h = _lstm_ops_so.block_lstm(
        seq_len_max=max_seq_len,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=w,
        wci=wci,
        wco=wco,
        wcf=wcf,
        b=b,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cs, h
评论列表
文章目录