def local_gpu_extract_diagonal(node):
    """Move diagonal extraction across host/GPU transfer boundaries.

    Two graph patterns are recognised:

    * ``extract_diagonal(host_from_gpu(x))`` is rewritten to
      ``host_from_gpu(extract_diagonal(x))``: the diagonal is taken from the
      GPU-side variable and the transfer back to the host is applied to the
      extracted result.
    * ``gpu_from_host(extract_diagonal(x))`` is rewritten to
      ``extract_diagonal(gpu_from_host(x))``, with the transfer performed via
      ``as_cuda_ndarray_variable``.

    Both rewrites only fire when the matrix fed to ``ExtractDiag`` has a
    host ``theano.tensor.TensorType``.

    Parameters
    ----------
    node : Apply node currently being visited by the optimizer.

    Returns
    -------
    list or bool
        A one-element list containing the replacement variable when one of
        the patterns matches, otherwise ``False``.
    """
    op = node.op
    host_tensor_type = theano.tensor.TensorType

    if isinstance(op, nlinalg.ExtractDiag):
        # Pattern 1: the diagonal's argument was just copied off the GPU.
        matrix = node.inputs[0]
        if isinstance(matrix.type, host_tensor_type):
            producer = matrix.owner
            if producer and isinstance(producer.op, HostFromGpu):
                gpu_matrix = as_cuda_ndarray_variable(matrix)
                gpu_diagonal = nlinalg.extract_diag(gpu_matrix)
                return [host_from_gpu(gpu_diagonal)]

    if isinstance(op, GpuFromHost):
        # Pattern 2: a host-side diagonal is about to be copied to the GPU;
        # copy the source matrix instead and extract on the GPU side.
        diag_apply = node.inputs[0].owner
        matches = (
            diag_apply
            and isinstance(diag_apply.op, nlinalg.ExtractDiag)
            and isinstance(diag_apply.inputs[0].type, host_tensor_type)
        )
        if matches:
            gpu_matrix = as_cuda_ndarray_variable(diag_apply.inputs[0])
            return [nlinalg.extract_diag(gpu_matrix)]

    return False
评论列表
文章目录