def ensure_float(val, default, name):
    """Return ``val`` as a float32 scalar variable, or a clone of ``default``.

    Parameters
    ----------
    val
        The value to validate. ``None`` selects the default. Anything that
        is not already a ``Variable`` is passed through ``constant`` first.
    default
        Object whose ``clone()`` result is returned when ``val`` is None.
    name
        Label used as the prefix of the error messages.

    Returns
    -------
    ``default.clone()`` when ``val`` is None, otherwise the (possibly
    converted) ``val``.

    Raises
    ------
    TypeError
        If the resulting value's type is not a ``theano.scalar.Scalar``,
        or if its dtype is not ``'float32'``.
    """
    if val is None:
        # Clone so the caller never receives the default object itself.
        return default.clone()

    if not isinstance(val, Variable):
        val = constant(val)

    # Zero-dimensional values are handed to ``as_scalar`` so that they can
    # satisfy the Scalar type check below.
    if hasattr(val, 'ndim') and val.ndim == 0:
        val = as_scalar(val)

    if not isinstance(val.type, theano.scalar.Scalar):
        raise TypeError("%s: expected a scalar value" % (name,))
    if val.type.dtype != 'float32':
        raise TypeError("%s: type is not float32" % (name,))
    return val
评论列表
文章目录