scan_utils.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def traverse(out, x, x_copy, d, visited=None):
    """
    Function used by scan to parse the tree and figure out which nodes
    it needs to replace.

    The graph rooted at ``out`` is walked depth-first. ``x`` is expected to
    be a GPU variable (``CudaNdarrayType`` or ``GpuArrayType``) and
    ``x_copy`` its host-side counterpart. Two kinds of nodes are recorded
    in ``d``:
        1) ``x`` itself, which is mapped to ``gpu_from_host(x_copy)`` so the
           replacement keeps living on the GPU;
        2) ``host_from_gpu(x)``, which is mapped directly to ``x_copy``
           (wrapped with ``tensor.as_tensor_variable``).
    This happens because initially shared variables are on GPU... which is
    fine for the main computational graph but confuses things a bit for the
    inner graph of scan.

    Parameters
    ----------
    out : Variable
        Root of the (sub)graph to search.
    x : Variable
        GPU variable to be replaced.
    x_copy : Variable
        Host variable that should take ``x``'s place.
    d : dict
        Replacement mapping; updated in place and also returned.
    visited : set, optional
        Variables already explored. If given, it is updated in place so the
        caller knows which nodes have been seen.

    Returns
    -------
    dict
        The same ``d`` object, with any new replacements added.

    Notes
    -----
    An explicit stack is used instead of recursion so that deep graphs
    cannot exceed Python's recursion limit. Nodes are visited in the same
    order a recursive pre-order traversal would visit them.

    """
    # ``visited`` is a set of nodes that are already known and don't need to be
    # checked again, speeding up the traversal of multiply-connected graphs.
    # If a ``visited`` set is given, it is updated in-place so the caller
    # knows which nodes we have seen.
    if visited is None:
        visited = set()
    if out in visited:
        return d
    # Imported inside the function, but now once per call rather than once
    # per visited node.
    from theano.sandbox import cuda
    from theano.gpuarray.basic_ops import gpu_from_host, host_from_gpu
    from theano.gpuarray import pygpu_activated
    from theano.gpuarray.type import GpuArrayType

    stack = [out]
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)

        if node == x:
            # ``x`` is used directly: transfer the host copy back to the GPU
            # so consumers still receive a GPU-typed value.
            if isinstance(x.type, cuda.CudaNdarrayType):
                d[node] = cuda.gpu_from_host(x_copy)
            else:
                assert isinstance(x.type, GpuArrayType)
                d[node] = gpu_from_host(x.type.context_name)(x_copy)
            continue

        owner = node.owner
        if owner is None:
            # No producing Apply node (e.g. a graph input): nothing to explore.
            continue

        # ``host_from_gpu(x)`` from either GPU back-end is replaced by the
        # host copy itself. Clauses are evaluated in the same order as the
        # two separate checks they replace.
        if ((cuda.cuda_available and
             owner.op == cuda.host_from_gpu and
             owner.inputs == [x]) or
            (pygpu_activated and
             owner.op == host_from_gpu and
             owner.inputs == [x])):
            d[node] = tensor.as_tensor_variable(x_copy)
            continue

        # Push inputs in reverse so they are popped in their original order,
        # reproducing the recursive pre-order visiting sequence.
        stack.extend(reversed(owner.inputs))
    return d


# Hashing a dictionary/list/tuple by xoring the hash of each element
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号