def as_floatX(variable):
    """Cast ``variable`` to the dtype named by ``theano.config.floatX``.

    Adapted from pylearn2.

    Dispatch is by the Python type of ``variable``:

    * ``float`` instances are converted to 0-D ``numpy.ndarray`` objects
      of dtype floatX.
    * ``numpy.ndarray`` instances are returned as a new ``numpy.ndarray``
      of dtype floatX.
    * Everything else is handed to ``theano.tensor.cast``.

    Parameters
    ----------
    variable : float, numpy.ndarray, or any object accepted by
        ``theano.tensor.cast``
        The value to convert.

    Returns
    -------
    numpy.ndarray or the result of ``theano.tensor.cast``
        ``variable`` converted to ``theano.config.floatX``.
    """
    if isinstance(variable, (float, numpy.ndarray)):
        # ``numpy.cast`` was removed in NumPy 2.0.  ``asarray(...).astype(...)``
        # is what it did internally, so the result (including the fact that a
        # fresh array is always returned) is unchanged on older NumPy versions.
        return numpy.asarray(variable).astype(theano.config.floatX)
    return theano.tensor.cast(variable, theano.config.floatX)
评论列表
文章目录