tfutil.py source code

python

Project: rltools    Author: sisl

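import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API


def unflatten_into_tensors(flatparams_P, output_shapes, name=None):
    """
    Splits a 1-D vector produced by flatcat back into a tuple of tensors
    with the specified shapes (the inverse of flatcat).
    """
    # tf.op_scope(values, name, default_name) is a deprecated pre-1.0 API;
    # tf.name_scope(name, default_name, values) is the TF 1.x equivalent.
    with tf.name_scope(name, 'unflatten_into_tensors', [flatparams_P]) as scope:
        outputs = []
        curr_pos = 0
        for shape in output_shapes:
            # Element count of this tensor; np.prod(()) == 1, so scalars take one slot.
            size = int(np.prod(shape))
            flatval = flatparams_P[curr_pos:curr_pos + size]
            outputs.append(tf.reshape(flatval, shape))
            curr_pos += size
        # Every element of the flat vector must be consumed exactly once.
        assert curr_pos == flatparams_P.get_shape().num_elements(), "{} != {}".format(
            curr_pos, flatparams_P.get_shape().num_elements())
        # tf.tuple groups the outputs: each is returned only after all are computed.
        return tf.tuple(outputs, name=scope)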
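You can check the slicing logic without building a TensorFlow graph. Below is a minimal NumPy sketch of the same split. It assumes that flatcat concatenates the raveled parameters in order, but flatcat itself is not shown on this page. The names `flatcat_np` and `unflatten_np` are hypothetical and do not come from rltools.

```python
import numpy as np


def flatcat_np(arrays):
    """Hypothetical NumPy stand-in for flatcat: ravel each array and concatenate in order."""
    return np.concatenate([np.ravel(a) for a in arrays])


def unflatten_np(flat, output_shapes):
    """NumPy sketch of unflatten_into_tensors: split a 1-D vector into arrays of the given shapes."""
    flat = np.asarray(flat)
    outputs = []
    curr_pos = 0
    for shape in output_shapes:
        size = int(np.prod(shape))  # np.prod(()) == 1.0, so a scalar shape () takes one element
        outputs.append(flat[curr_pos:curr_pos + size].reshape(shape))
        curr_pos += size
    # Same consistency check as the TF version, but raised as an exception instead of an assert.
    if curr_pos != flat.size:
        raise ValueError("{} != {}".format(curr_pos, flat.size))
    return tuple(outputs)


W = np.arange(6.0).reshape(2, 3)   # a 2x3 weight matrix
b = np.array([10.0, 20.0, 30.0])   # a bias vector
s = np.array(7.0)                  # a scalar parameter
flat = flatcat_np([W, b, s])       # 6 + 3 + 1 = 10 elements
W2, b2, s2 = unflatten_np(flat, [(2, 3), (3,), ()])
```

The check uses a `ValueError` rather than `assert`, so it still runs under `python -O`. If the shapes do not account for every element of the vector, the error fires. For example, passing only `[(2, 3)]` for the 10-element vector above raises it.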