state_util.py source code

Project: magenta    Author: tensorflow

```python
import numpy as np
from tensorflow.python.util import nest as tf_nest


def batch(states, batch_size=None):
  """Combines a collection of state structures into a batch, padding if needed.

  Args:
    states: A collection of individual nested state structures.
    batch_size: The desired final batch size. If fewer than `batch_size`
        states are given, each stacked array is padded with zeros along its
        leading (batch) axis.
  Returns:
    A single state structure that results from stacking the structures in
    `states`, with padding if needed.

  Raises:
    ValueError: If the number of input states is larger than `batch_size`.
  """
  if batch_size and len(states) > batch_size:
    raise ValueError('Combined state is larger than the requested batch size')

  def stack_and_pad(*leaves):
    # Stack the corresponding leaf from every state along a new leading axis.
    stacked = np.stack(leaves)
    if batch_size:
      # Grow only the leading axis in place; ndarray.resize fills the new
      # entries with zeros, which become the padding rows at the end.
      stacked.resize([batch_size] + list(stacked.shape)[1:])
    return stacked

  # Apply stack_and_pad leaf by leaf across all of the nested structures.
  return tf_nest.map_structure(stack_and_pad, *states)
```
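
To illustrate what `batch` returns, here is a NumPy-only sketch that mirrors its logic under two assumptions. First, a small hypothetical `map_structure` helper stands in for `tf_nest.map_structure`; it only understands plain dicts, lists, and tuples. Second, padding uses `np.pad` instead of the in-place `ndarray.resize`, which yields the same trailing zero rows without relying on resize's reference-count check. This is a sketch for experimentation, not the magenta implementation.

```python
import numpy as np


def map_structure(fn, *structures):
  """Hypothetical, minimal stand-in for tf_nest.map_structure (assumption).

  Recurses into plain dicts, lists and tuples; anything else is a leaf.
  Namedtuples and other containers are not handled.
  """
  first = structures[0]
  if type(first) is dict:
    return {key: map_structure(fn, *(s[key] for s in structures))
            for key in first}
  if type(first) in (list, tuple):
    return type(first)(
        map_structure(fn, *items) for items in zip(*structures))
  return fn(*structures)


def batch(states, batch_size=None):
  """Stacks nested states leaf by leaf, zero-padding the batch axis."""
  if batch_size and len(states) > batch_size:
    raise ValueError('Combined state is larger than the requested batch size')

  def stack_and_pad(*leaves):
    stacked = np.stack(leaves)
    if batch_size:
      # Pad only the leading (batch) axis; trailing axes keep their size.
      pad_width = [(0, batch_size - stacked.shape[0])]
      pad_width += [(0, 0)] * (stacked.ndim - 1)
      stacked = np.pad(stacked, pad_width, mode='constant')
    return stacked

  return map_structure(stack_and_pad, *states)


# Two individual states, each a dict with a vector and a scalar leaf.
states = [
    {'h': np.array([1., 2.]), 'step': np.int64(3)},
    {'h': np.array([3., 4.]), 'step': np.int64(7)},
]
batched = batch(states, batch_size=3)
print(batched['h'])
# [[1. 2.]
#  [3. 4.]
#  [0. 0.]]
print(batched['step'])  # [3 7 0]
```

Each leaf keeps its own trailing shape and only the leading batch axis is padded: `'h'` goes from two length-2 vectors to a 3 x 2 array, and the scalar `'step'` becomes a length-3 vector ending in a zero.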