def local_abstractconv3d_gradweight_gemm(node):
    """Swap an ``AbstractConv3d_gradWeights`` node for ``Corr3dMM_gradWeights``.

    Parameters
    ----------
    node
        Graph node to inspect. Its ``op`` must be an
        ``AbstractConv3d_gradWeights`` and its inputs must unpack to
        ``(img, topgrad, shape)``.

    Returns
    -------
    list or None
        A one-element list holding the replacement output variable, or
        ``None`` when the rewrite does not apply:

        * ``theano.config.cxx`` is the empty string or
          ``theano.config.blas.ldflags`` is falsy,
        * ``node.op`` is not an ``AbstractConv3d_gradWeights``,
        * either ``img`` or ``topgrad`` is not of ``TensorType``.
    """
    # Both a configured C++ compiler and BLAS link flags are required.
    if not (theano.config.cxx != "" and theano.config.blas.ldflags):
        return None

    conv_op = node.op
    if not isinstance(conv_op, AbstractConv3d_gradWeights):
        return None

    img, topgrad, shape = node.inputs
    both_tensor_typed = (isinstance(img.type, TensorType) and
                         isinstance(topgrad.type, TensorType))
    if not both_tensor_typed:
        return None

    # Build the concrete op with the same convolution settings as the
    # abstract one, then apply it to the same inputs.
    gemm_op = Corr3dMM_gradWeights(
        border_mode=conv_op.border_mode,
        subsample=conv_op.subsample,
        filter_dilation=conv_op.filter_dilation,
    )
    new_out = gemm_op(img, topgrad, shape)
    old_out = node.outputs[0]
    copy_stack_trace(old_out, new_out)

    # When the abstract op flips its filters, reverse the last three
    # axes of the computed weight gradient to match.
    if conv_op.filter_flip:
        new_out = new_out[:, :, ::-1, ::-1, ::-1]

    # Give the replacement the same broadcastable pattern as the output
    # it stands in for, and re-attach the stack trace to the final variable.
    new_out = theano.tensor.patternbroadcast(new_out, old_out.broadcastable)
    copy_stack_trace(old_out, new_out)
    return [new_out]
评论列表
文章目录