Conv3D.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def grad(self, inputs, output_gradients):
        """Return the symbolic gradients of Conv3D with respect to its inputs.

        Parameters
        ----------
        inputs : list
            The op's four symbolic inputs, unpacked as ``V, W, b, d``.
            Presumably the input video, filters, bias and stride -- confirm
            against this op's ``make_node``.
        output_gradients : list
            Single-element list holding ``dCdH``, the gradient of the cost
            with respect to this op's output.

        Returns
        -------
        list
            ``[dCdV, dCdW, dCdb, dCdd]``. ``dCdV``, ``dCdW`` and ``dCdb``
            carry the same broadcastable pattern as ``V``, ``W`` and ``b``.
            ``dCdd`` is an undefined-gradient marker from ``grad_undefined``,
            because Conv3D is only defined for integer strides. Each
            defined gradient gets a descriptive ``.name`` built from the
            names of the variables involved, for debugging.
        """
        V, W, b, d = inputs
        dCdH, = output_gradients
        # TODO: make all of these ops support broadcasting of scalar b to
        # vector b and replace the zeros_like in all their grads.

        # Each gradient is passed through patternbroadcast so that its
        # broadcasting pattern is the same as that of the variable it is
        # the gradient of.
        # The transposed convolution is given an all-zero bias shaped like
        # V[0, 0, 0, 0, :] (presumably one entry per input channel) and
        # V.shape[1:4] as the output spatial shape.
        dCdV = theano.tensor.nnet.convTransp3D(
            W, T.zeros_like(V[0, 0, 0, 0, :]), d, dCdH, V.shape[1:4])
        dCdV = T.patternbroadcast(dCdV, V.broadcastable)
        dCdW = theano.tensor.nnet.convGrad3D(V, d, W.shape, dCdH)
        dCdW = T.patternbroadcast(dCdW, W.broadcastable)
        # Summing dCdH over axes 0-3 leaves one entry per position of its
        # last axis, matching the bias vector.
        dCdb = T.sum(dCdH, axis=(0, 1, 2, 3))
        dCdb = T.patternbroadcast(dCdb, b.broadcastable)
        # The stride is input index 3; it has no meaningful gradient.
        dCdd = grad_undefined(
            self, 3, d,
            "The gradient of Conv3D with respect to the convolution"
            " stride is undefined because Conv3D is only defined for"
            " integer strides.")

        def _label(var, default):
            # Return var.name if it exists and is not None, else `default`,
            # so the debug names built below are always plain strings.
            name = getattr(var, 'name', None)
            return default if name is None else name

        dCdH_name = _label(dCdH, 'anon_dCdH')
        V_name = _label(V, 'anon_V')
        W_name = _label(W, 'anon_W')
        b_name = _label(b, 'anon_b')

        dCdV.name = 'Conv3D_dCdV(dCdH=' + dCdH_name + ',V=' + V_name + ')'
        dCdW.name = ('Conv3D_dCdW(dCdH=' + dCdH_name + ',V=' + V_name +
                     ',W=' + W_name + ')')
        dCdb.name = ('Conv3D_dCdb(dCdH=' + dCdH_name + ',V=' + V_name +
                     ',W=' + W_name + ',b=' + b_name + ')')

        return [dCdV, dCdW, dCdb, dCdd]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号