def _pack_permutation(axis, rank):
    """
    Return the transpose permutation that moves dimension 0 to ``axis``.

    All other dimensions keep their relative order, e.g. for ``axis=2`` and
    ``rank=4`` the result is ``[1, 2, 0, 3]``.

    Parameters:
    ----------
    axis: int
        destination index of dimension 0; negative values count from the end
    rank: int
        number of dimensions of the tensor being permuted
    Returns: list of int
        a permutation of ``range(rank)``
    Raises:
    ------
    ValueError
        if ``axis`` is outside ``[-rank, rank)``
    """
    if not -rank <= axis < rank:
        raise ValueError(
            "axis %d is out of range for a tensor of rank %d" % (axis, rank))
    if axis < 0:
        axis += rank
    # list(range(...)) instead of bare range(): list + range raises TypeError
    # on Python 3. Dims 1..axis shift left by one, dim 0 lands at `axis`, and
    # the trailing dims stay in place.
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


def pack_into_tensor(array, axis):
    """
    Pack a TensorArray into a single tensor along a given axis.

    ``array.pack()`` stacks the array's elements along a new leading
    dimension 0; that dimension is then moved to position ``axis`` while the
    elements' own dimensions keep their original order (the same layout as
    stacking the elements with ``axis``).

    Parameters:
    ----------
    array: TensorArray
        the tensor array to pack
    axis: int
        the axis of the result along which the array's elements are laid
        out; negative values count from the end
    Returns: Tensor
        the packed tensor
    Raises:
    ------
    ValueError
        if ``axis`` is out of range for the packed tensor's rank
    """
    # NOTE(review): TensorArray.pack() was renamed stack() in TensorFlow 1.0;
    # update this call if the project targets TF >= 1.0.
    packed_tensor = array.pack()
    # The static rank must be known here, since len() is taken of the shape.
    rank = len(packed_tensor.get_shape())
    dim_permutation = _pack_permutation(axis, rank)
    return tf.transpose(packed_tensor, dim_permutation)
评论列表
文章目录