def handle_shared_float32(tf):
    """
    Switch CudaNdarrayType on or off as the default float32 shared type.

    This function is intended to be called from use(gpu_index), not
    directly.

    When `tf` is truthy, `float32_shared_constructor` is handed to
    `theano.compile.shared_constructor`. Otherwise the same call is made
    with `True` as its second argument, which is expected to unregister
    the constructor; this is then checked against
    `theano.compile.shared.constructors`.
    """
    register = theano.compile.shared_constructor

    if tf:
        register(float32_shared_constructor)
        return

    register(float32_shared_constructor, True)
    # Sanity check: after the call above, the GPU constructor must no
    # longer be among the registered shared constructors.
    remaining = theano.compile.shared.constructors
    assert float32_shared_constructor not in remaining
# We can't test the driver during import here, as this causes a circular
# import dependency. So we also test it in the file theano/__init__.py.
评论列表
文章目录