def local_conv2d_cpu(node):
if not isinstance(node.op, AbstractConv2d):
return None
img, kern = node.inputs
if ((not isinstance(img.type, TensorType) or
not isinstance(kern.type, TensorType))):
return None
if node.op.border_mode not in ['full', 'valid']:
return None
if not node.op.filter_flip:
# Not tested yet
return None
rval = conv2d(img, kern,
node.op.imshp, node.op.kshp,
border_mode=node.op.border_mode,
subsample=node.op.subsample)
copy_stack_trace(node.outputs[0], rval)
return [rval]
评论列表
文章目录