utility.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:Neural-Turing-Machine 作者: camigord 项目源码 文件源码
def pack_into_tensor(array, axis):
    """
    packs a given TensorArray into a tensor along a given axis

    Parameters:
    ----------
    array: TensorArray
        the tensor array to pack
    axis: int
        the axis to pack the array along

    Returns: Tensor
        the packed tensor
    """

    packed_tensor = array.pack()
    shape = packed_tensor.get_shape()
    rank = len(shape)

    dim_permutation = [axis] + range(1, axis) + [0]  + range(axis + 1, rank)
    correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation)

    return correct_shape_tensor
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号