def as_scalar(x, name=None):
    """Return `x` as a scalar Variable.

    Parameters
    ----------
    x
        One of:
        - an Apply node with exactly one output (that output is used),
        - a Variable whose type is `Scalar` or a 0-d `TensorType`,
        - any other value accepted by `constant`.
    name
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    Variable
        `x` itself if it is already a Scalar-typed Variable;
        `scalar_from_tensor(x)` if it is a 0-d tensor Variable;
        otherwise `constant(x)`.

    Raises
    ------
    ValueError
        If `x` is an Apply node that does not have exactly one output.
    TypeError
        If `x` is a Variable of any other type, or if `constant(x)` raises
        TypeError.
    """
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import with the tensor package -- confirm.
    from ..tensor import TensorType, scalar_from_tensor

    if isinstance(x, gof.Apply):
        # Only a single-output node unambiguously identifies one Variable.
        if len(x.outputs) != 1:
            raise ValueError("It is ambiguous which output of a multi-output"
                             " Op has to be fetched.", x)
        x = x.outputs[0]

    if isinstance(x, Variable):
        if isinstance(x.type, Scalar):
            return x
        if isinstance(x.type, TensorType) and x.ndim == 0:
            return scalar_from_tensor(x)
        raise TypeError("Variable type field must be a Scalar.", x, x.type)

    try:
        return constant(x)
    except TypeError:
        # Wrap x in a 1-tuple. Otherwise a tuple input would itself be
        # taken as the %-format argument tuple. That either raises an
        # unrelated TypeError or prints only the first element.
        raise TypeError("Cannot convert %s to Scalar" % (x,), type(x))
评论列表
文章目录