rnn_cell.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:DL-Benchmarks 作者: DL-Benchmarks 项目源码 文件源码
def _get_sharded_variable(name, shape, dtype, num_shards):
  """Get a list of sharded variables with the given dtype."""
  if num_shards > shape[0]:
    raise ValueError("Too many shards: shape=%s, num_shards=%d" %
                     (shape, num_shards))
  unit_shard_size = int(math.floor(shape[0] / num_shards))
  remaining_rows = shape[0] - unit_shard_size * num_shards

  shards = []
  for i in range(num_shards):
    current_size = unit_shard_size
    if i < remaining_rows:
      current_size += 1
    shards.append(vs.get_variable(name + "_%d" % i, [current_size, shape[1]],
                                  dtype=dtype))
  return shards
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号