convolution.py source code

python

Project: pyextremelm  Author: tobifinn

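import numpy as np
import theano
import theano.tensor as T


def fit(self, X, y=None):
    # Expected layouts: X is (batch, in_channels, rows, cols) and y is
    # (batch, out_channels, rows', cols'), so the output count is y.shape[1].
    self.n_features = y.shape[1]
    floatX = theano.config.floatX
    self.weights['input'] = theano.shared(value=np.zeros((
        self.n_features, X.shape[1], self.spatial[0], self.spatial[1]),
        dtype=floatX), name='w', borrow=True)
    input = T.tensor4(name='input')
    target = T.tensor4(name='target')
    decay = T.scalar(name='decay')
    # Input/target cross-products summed over the batch: swapping the batch
    # and channel axes lets conv2d perform that sum.
    # Result: (in_channels, out_channels, r, c), expected to match the kernel.
    xy = T.nnet.conv2d(input.transpose(1, 0, 2, 3),
                       target.transpose(1, 0, 2, 3),
                       border_mode=self.pad, subsample=self.stride)
    # Energy of each input channel, shape (in_channels,).
    xx = T.sum(T.power(input, 2), axis=(0, 2, 3))

    # Regularisation constant, cast to floatX so the updates keep the
    # dtype of the shared variables they overwrite.
    lam = theano.shared(value=np.asarray(self._C, dtype=floatX),
                        name='constrain', borrow=True)
    prediction = T.nnet.conv2d(input, self.weights['input'],
                               border_mode=self.pad,
                               subsample=self.stride)
    # hidden_matrices['K'] (in_channels,) and hidden_matrices['A'] (same shape
    # as xy) are expected to be zero-initialised floatX shared variables set up
    # elsewhere in the class.
    new_k = decay * self.hidden_matrices['K'] + xx
    new_a = decay * self.hidden_matrices['A'] + xy
    # w[o, i] = A[i, o] / (K[i] + C), broadcast over input channels and computed
    # from the updated statistics so that a single call yields non-zero weights.
    new_weights = (new_a.dimshuffle(1, 0, 2, 3) /
                   (new_k.dimshuffle('x', 0, 'x', 'x') + lam))
    updates = [(self.hidden_matrices['K'], new_k),
               (self.hidden_matrices['A'], new_a),
               (self.weights['input'], new_weights)]
    self.conv_fct['train'] = theano.function([input, target, decay],
                                             prediction,
                                             updates=updates,
                                             allow_input_downcast=True)
    self.conv_fct['predict'] = theano.function([input], prediction,
                                               allow_input_downcast=True)
    return self.conv_fct['train'](X, y, 1.0)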
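The training step above keeps two exponentially decayed statistics, the per-channel input energy `K` and the input/target cross-products `A`, and sets the weights to `A / (K + C)`. Read as a diagonal approximation of ridge regression (an interpretation, not something the source states), the same update can be sketched in NumPy for the dense, 1x1-kernel case. The class `DecayedDiagonalRidge` and its `partial_fit` method are hypothetical names introduced here for illustration; they are not part of pyextremelm.

```python
import numpy as np


class DecayedDiagonalRidge:
    """Hypothetical dense (1x1 kernel) analogue of the update in `fit`.

    Keeps exponentially decayed statistics K (per-feature energy) and
    A (feature/target cross-products) and sets W = A / (K + C) row-wise.
    """

    def __init__(self, n_in, n_out, C=1.0):
        self.C = C                           # plays the role of `lam` / self._C
        self.K = np.zeros(n_in)              # decayed sum of x_i ** 2
        self.A = np.zeros((n_in, n_out))     # decayed sum of x_i * y_o
        self.W = np.zeros((n_in, n_out))

    def partial_fit(self, X, Y, decay=1.0):
        # Predict with the current weights first, as a Theano function's
        # outputs are computed before its updates are applied.
        prediction = X @ self.W
        self.K = decay * self.K + np.sum(X ** 2, axis=0)
        self.A = decay * self.A + X.T @ Y
        # Diagonal approximation of (X^T X + C I)^{-1} X^T Y.
        self.W = self.A / (self.K + self.C)[:, None]
        return prediction


rng = np.random.default_rng(0)
X1, Y1 = rng.normal(size=(5, 3)), rng.normal(size=(5, 2))
X2, Y2 = rng.normal(size=(4, 3)), rng.normal(size=(4, 2))

online = DecayedDiagonalRidge(3, 2, C=0.5)
online.partial_fit(X1, Y1)
online.partial_fit(X2, Y2)

batch = DecayedDiagonalRidge(3, 2, C=0.5)
batch.partial_fit(np.vstack([X1, X2]), np.vstack([Y1, Y2]))
print(np.allclose(online.W, batch.W))  # True: decay=1.0 just accumulates
```

Because `decay=1.0` simply accumulates, two `partial_fit` calls give the same weights as one call on the stacked batches, while `decay=0.0` discards the earlier batch entirely. When the feature columns are mutually orthogonal, X^T X is diagonal and the sketch coincides with exact ridge regression.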