def unpack_into_tensorarray(value, axis, size=None):
    """
    Unpack a tensor along a given axis into a TensorArray.

    Parameters:
    ----------
    value: Tensor
        the tensor to be unpacked; its rank is read from the static shape
        (``value.get_shape().as_list()``)
    axis: int
        the axis to unpack the tensor along; negative values count from the
        last dimension
    size: int, optional
        the size of the array to be used if static shape inference gives None
        for ``axis``; ignored when that dimension is statically known

    Returns: TensorArray
        a TensorArray whose i-th element is the slice of ``value`` at index i
        along ``axis``. Dimensions 0 and ``axis`` are swapped (not rotated)
        before unpacking, so for ``axis >= 2`` each element's dimension order
        is ``shape[1:axis] + [shape[0]] + shape[axis+1:]``.

    Raises:
    ------
    IndexError
        if ``axis`` is out of range for the rank of ``value``
    ValueError
        if the size along ``axis`` is unknown and ``size`` is None
    """
    shape = value.get_shape().as_list()
    rank = len(shape)
    if not -rank <= axis < rank:
        raise IndexError(
            "axis %d is out of range for a tensor of rank %d" % (axis, rank))
    # Normalise negative axes so the permutation below is well formed.
    if axis < 0:
        axis += rank

    array_size = shape[axis] if shape[axis] is not None else size
    if array_size is None:
        raise ValueError("Can't create TensorArray with size None")
    array = tf.TensorArray(dtype=value.dtype, size=array_size)

    # Swap dimension 0 with `axis` so that `axis` becomes the leading
    # dimension the TensorArray splits along. A swap is its own inverse, so
    # applying the same permutation again restores the original layout.
    # For axis == 0 the permutation would be the identity, so skip it.
    if axis != 0:
        perm = list(range(rank))
        perm[0], perm[axis] = perm[axis], perm[0]
        value = tf.transpose(value, perm)

    # TensorArray.unpack was renamed to unstack in newer TensorFlow releases;
    # prefer the new name and fall back to the old one.
    unstack = getattr(array, "unstack", None) or array.unpack
    return unstack(value)
评论列表
文章目录