def wrapped_conv(*args, **kwargs):
    """Convolve via cuDNN when not on the CPU, otherwise via ``conv2d``.

    Args:
        *args: Exactly four positional values, unpacked as
            ``(input, W, input_shape, get_W_shape)``. All four are forwarded
            unchanged to ``theano.tensor.nnet.conv2d``; only the first two
            are passed to ``dnn_conv`` (cast to ``float32``).
        **kwargs: Keyword arguments forwarded unchanged to ``conv2d``. For
            the ``dnn_conv`` path a copy is used with ``image_shape``,
            ``filter_shape`` and ``filter_flip`` removed. ``filter_flip``
            must be supplied and truthy.

    Returns:
        Whatever ``theano.tensor.nnet.conv2d`` or
        ``theano.sandbox.cuda.dnn.dnn_conv`` returns.

    Raises:
        AssertionError: If ``filter_flip`` is missing or falsy.
        ValueError: If the number of positional arguments is not four
            (raised by the tuple unpacking).
    """
    # Work on a copy so the kwargs passed to conv2d keep every key.
    dnn_kwargs = dict(kwargs)
    # Dropped before calling dnn_conv -- presumably because dnn_conv does not
    # accept these conv2d-specific keywords (confirm against the Theano docs).
    dnn_kwargs.pop("image_shape", None)
    dnn_kwargs.pop("filter_shape", None)

    # The pop must not live inside an `assert`: under `python -O` the whole
    # statement is stripped, which would skip the check AND leave
    # `filter_flip` in the kwargs forwarded to dnn_conv.
    filter_flip = dnn_kwargs.pop("filter_flip", False)
    if not filter_flip:
        # AssertionError is kept so existing callers see the same exception.
        raise AssertionError("wrapped_conv requires a truthy filter_flip")

    # Unpacking enforces the four-positional-argument contract; the two
    # shape values are only used through *args below.
    data, W, _input_shape, _get_W_shape = args

    if theano.config.device == 'cpu':
        return theano.tensor.nnet.conv2d(*args, **kwargs)

    try:
        return theano.sandbox.cuda.dnn.dnn_conv(
            data.astype('float32'),
            W.astype('float32'),
            **dnn_kwargs
        )
    except Exception:
        # Deliberate best-effort: any failure on the cuDNN path (including
        # the float32 casts) falls back to the generic implementation.
        print("falling back to default conv2d")
        return theano.tensor.nnet.conv2d(*args, **kwargs)
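# 评论列表 (comment list; page-navigation text left over from the web page)
# 文章目录 (table of contents; page-navigation text left over from the web page)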