utility.py source code

python

Project: Neural-Turing-Machine  Author: camigord
import tensorflow as tf


def unpack_into_tensorarray(value, axis, size=None):
    """
    Unpacks a tensor along the given axis into a TensorArray.

    Parameters:
    ----------
    value: Tensor
        the tensor to be unpacked
    axis: int
        the axis to unpack the tensor along
    size: int, optional
        the array size to use when the static shape along `axis` is
        unknown (None)

    Returns: TensorArray
        the unpacked TensorArray, with one element per slice along `axis`
    """

    shape = value.get_shape().as_list()
    rank = len(shape)
    dtype = value.dtype
    array_size = shape[axis] if shape[axis] is not None else size

    if array_size is None:
        raise ValueError("Can't create TensorArray with size None")

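    array = tf.TensorArray(dtype=dtype, size=array_size)
    # Bring `axis` to the front by swapping it with dimension 0. The list()
    # calls are required on Python 3, where range() is not a list.
    if axis == 0:
        # Already axis-major; the swap formula would repeat dimension 0.
        dim_permutation = list(range(rank))
    else:
        dim_permutation = ([axis] + list(range(1, axis)) + [0]
                           + list(range(axis + 1, rank)))
    unpack_axis_major_value = tf.transpose(value, dim_permutation)
    # Newer TensorFlow releases use `unstack` in place of `unpack`.
    full_array = array.unstack(unpack_axis_major_value)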

    return full_array
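
The `dim_permutation` list in the function above moves the unpack axis to position 0 by swapping it with the first dimension, so that the TensorArray is filled with one slice per index along that axis. The following is a minimal pure-Python sketch of that swap, without TensorFlow; the helper names `axis_major_permutation` and `permuted_shape` are hypothetical and do not come from the project.

```python
def axis_major_permutation(axis, rank):
    """Return the dimension order that swaps `axis` with dimension 0.

    Hypothetical helper for illustration; not part of the original project.
    """
    if not 0 <= axis < rank:
        raise ValueError("axis must satisfy 0 <= axis < rank")
    perm = list(range(rank))
    # Swap the chosen axis with the leading dimension.
    perm[0], perm[axis] = perm[axis], perm[0]
    return perm


def permuted_shape(shape, perm):
    """Return the shape a tensor would have after transposing with `perm`."""
    return [shape[i] for i in perm]


# Unpacking a rank-4 tensor of shape [8, 3, 5, 7] along axis 2:
perm = axis_major_permutation(2, 4)
print(perm)                                # [2, 1, 0, 3]
print(permuted_shape([8, 3, 5, 7], perm))  # [5, 3, 8, 7]
```

After the transpose the leading dimension has length 5, so the TensorArray would hold 5 elements, each of shape [3, 8, 7]. Note that the other dimensions are swapped rather than shifted, so dimension 0 ends up where `axis` used to be.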