def batch(states, batch_size=None):
    """Combines a collection of state structures into a batch, padding if needed.

    Each leaf of the returned structure is the `np.stack` of the corresponding
    leaves of the input structures, so its leading axis indexes the input
    states. When `batch_size` is given, that leading axis is zero-padded up to
    `batch_size` entries.

    Args:
      states: A collection of individual nested state structures. They are
        combined leaf-by-leaf with `tf_nest.map_structure`, so their nesting
        must be compatible with that function. Any iterable is accepted; it is
        materialized into a list once.
      batch_size: The desired final batch size, or None for no padding. If the
        nested state structure that results from combining the states is
        smaller than this, it will be padded with zeros.

    Returns:
      A single state structure that results from stacking the structures in
      `states`, with padding if needed.

    Raises:
      ValueError: If `states` is empty, or if the number of input states is
        larger than `batch_size`.
    """
    # Materialize once: `states` is both measured with len() and unpacked
    # below, which would misbehave for a one-shot iterator.
    states = list(states)
    if not states:
        raise ValueError('Cannot batch an empty collection of states')
    # Compare against None explicitly so that batch_size=0 is validated
    # instead of being treated as "no batch size requested".
    if batch_size is not None and len(states) > batch_size:
        raise ValueError('Combined state is larger than the requested batch size')

    def stack_and_pad(*leaves):
        """Stacks matching leaves along a new axis 0 and zero-pads that axis."""
        stacked = np.stack(leaves)
        if batch_size is None:
            return stacked
        num_padding = batch_size - stacked.shape[0]
        if num_padding == 0:
            return stacked
        # Build the padding explicitly rather than calling `ndarray.resize` in
        # place: with its default refcheck, resize raises ValueError when NumPy
        # detects other references to the array (e.g. under a debugger).
        padding = np.zeros(
            (num_padding,) + stacked.shape[1:], dtype=stacked.dtype)
        return np.concatenate([stacked, padding], axis=0)

    return tf_nest.map_structure(stack_and_pad, *states)
评论列表
文章目录