def get_state_variables_for_batch(state_variables, batch_size):
    """Return the state variables restricted to the rows of the current batch.

    Takes the state variables produced by get_state_variables() and, for use in an
    actual forward-propagation run, slices every cell state and hidden state down to
    its first ``batch_size`` rows (i.e. along the first dimension).

    Before this call each variable's first dimension has length max_batch_size. When
    the input batch is smaller than that, the LSTM should only update the state rows
    of the batches actually in use, which is exactly what this slice selects.

    See get_state_variables() for more info.

    Args:
        state_variables (tuple[tf.contrib.rnn.LSTMStateTuple]): The LSTM's state
            variables; each entry is unpacked as a (c, h) pair.
        batch_size (tf.Tensor): A 0-dimensional tensor holding the batch size in the
            computational graph.

    Returns:
        tuple[tf.contrib.rnn.LSTMStateTuple]: A new tuple, in the same order as
            ``state_variables``, whose c and h entries are clipped to ``batch_size``
            rows.
    """
    # Clip each (c, h) pair and materialise the result as a tuple in one pass. The
    # LSTMStateTuple constructor is looked up per item, so an empty input never
    # touches tf.contrib at all.
    return tuple(
        tf.contrib.rnn.LSTMStateTuple(cell[:batch_size], hidden[:batch_size])
        for cell, hidden in state_variables
    )
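# Scraped web-page residue from the original source (not code):
# "评论列表" (comment list), "文章目录" (table of contents).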