def unflatten_into_tensors(flatparams_P, output_shapes, name=None):
    """
    Split a flat vector back into a list of tensors with the given shapes.

    Counterpart of flatcat (defined elsewhere): consecutive slices of
    `flatparams_P` are taken in the order of `output_shapes`; each slice
    holds np.prod(shape) elements and is reshaped to that shape.

    Args:
        flatparams_P: flat tensor to split; it is sliced along its first
            axis and must expose a static shape via get_shape().
        output_shapes: iterable of shapes, one per output tensor.
        name: optional name for the op scope; the scope defaults to
            'unflatten_into_tensors'.

    Returns:
        The result of tf.tuple over the reshaped slices, named after the
        op scope.

    Raises:
        AssertionError: if the sizes of `output_shapes` do not add up to
            the static element count of `flatparams_P`.
    """
    with tf.op_scope([flatparams_P], name, 'unflatten_into_tensors') as scope:
        offset = 0
        pieces = []
        for shape in output_shapes:
            n_elems = np.prod(shape).astype('int')
            end = offset + n_elems
            pieces.append(tf.reshape(flatparams_P[offset:end], shape))
            offset = end
        # Every element of the flat vector must be consumed by the shapes.
        total = flatparams_P.get_shape().num_elements()
        assert offset == total, "{} != {}".format(offset, total)
        return tf.tuple(pieces, name=scope)
评论列表
文章目录