PhasedLSTMCell.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:PLSTM 作者: Enny1991 项目源码 文件源码
def _get_sharded_variable(name, shape, dtype, num_shards):
    """Get a list of sharded variables with the given dtype."""
    if num_shards > shape[0]:
        raise ValueError("Too many shards: shape=%s, num_shards=%d" %
                         (shape, num_shards))
    unit_shard_size = int(math.floor(shape[0] / num_shards))
    remaining_rows = shape[0] - unit_shard_size * num_shards

    shards = []
    for i in range(num_shards):
        current_size = unit_shard_size
        if i < remaining_rows:
            current_size += 1
        shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
                                      dtype=dtype))
    return shards
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号