def fit(self, X, y=None):
    """Compile the Theano train/predict functions and run one training step.

    A zero-initialised filter bank is created in ``self.weights['input']``
    and two compiled functions are stored in ``self.conv_fct``:

    * ``'train'``   -- takes ``(input, target, decay)``, returns the
      convolution of ``input`` with the current weights, and updates the
      accumulators ``self.hidden_matrices['K']`` / ``['A']`` as well as
      ``self.weights['input']``.
    * ``'predict'`` -- takes ``(input,)`` and returns the same convolution
      without any update.

    ``self.hidden_matrices['A']`` and ``['K']`` are used as update targets,
    so they are expected to already be Theano shared variables.

    Args:
        X: 4-D input array; ``X.shape[1]`` sets the second dimension of
            the filter bank.
        y: 4-D target array. Required despite the ``None`` default.

    Returns:
        The output of ``self.conv_fct['train'](X, y, 1)``.

    Raises:
        ValueError: if ``y`` is None.
    """
    if y is None:
        # Fail early with a clear message instead of an AttributeError on
        # ``None.shape`` below.
        raise ValueError("fit() requires a target array `y`; got None.")

    # NOTE(review): this is the leading axis of `y`; if `y` is laid out as
    # (batch, channels, h, w) the channel axis would be shape[1] -- confirm.
    self.n_features = y.shape[0]
    kernel_shape = (self.n_features, X.shape[1],
                    self.spatial[0], self.spatial[1])
    self.weights['input'] = theano.shared(
        value=np.zeros(kernel_shape, dtype=theano.config.floatX),
        name='w', borrow=True)

    # Symbolic inputs of the compiled functions. The Python local is called
    # `inputs` to avoid shadowing the builtin; the Theano name is unchanged.
    inputs = T.tensor4(name='input')
    target = T.tensor4(name='target')
    decay = T.scalar(name='decay')

    # Convolution of the axis-swapped input with the axis-swapped target,
    # and per-axis-1 sum of squared inputs.
    xy = T.nnet.conv2d(inputs.transpose(1, 0, 2, 3),
                       target.transpose(1, 0, 2, 3),
                       border_mode=self.pad, subsample=self.stride)
    xx = T.sum(T.power(inputs, 2), axis=(0, 2, 3))

    # NOTE(review): a malformed, unused statement
    # `k = ifelse(self.hidden_matrices['input'] is None, )` was dropped
    # here. A lazy initialisation of the hidden matrices may have been
    # intended -- confirm.

    lam = theano.shared(value=self._C, name='constrain', borrow=True)
    prediction = T.nnet.conv2d(inputs, self.weights['input'],
                               border_mode=self.pad,
                               subsample=self.stride)

    # For each slice along the leading axis: A_slice / (K_entry + lam).
    weights, _ = theano.scan(
        fn=lambda a, k, c: a / (k + c),
        outputs_info=None,
        sequences=[self.hidden_matrices['A'].transpose(1, 0, 2, 3),
                   self.hidden_matrices['K']],
        non_sequences=lam)
    new_weights = weights.transpose(1, 0, 2, 3)

    # Scale the running accumulators by `decay`, add the new statistics,
    # and replace the filter bank with the recomputed weights.
    updates = [
        (self.hidden_matrices['K'],
         self.hidden_matrices['K'].dot(decay) + xx),
        (self.hidden_matrices['A'],
         self.hidden_matrices['A'].dot(decay) + xy),
        (self.weights['input'], new_weights),
    ]

    self.conv_fct['train'] = theano.function([inputs, target, decay],
                                             prediction,
                                             updates=updates)
    self.conv_fct['predict'] = theano.function([inputs], prediction)
    return self.conv_fct['train'](X, y, 1)
# [page navigation text, not code] Comment list
# [page navigation text, not code] Article table of contents