def create_updates(self, input):
    """Build the running-statistics updates for the current batch.

    Computes the mean and variance of ``input`` and blends them, weighted
    by ``self.momentum``, into the running mean and variance. The resulting
    ``(variable, new_value)`` pairs are stored in ``self.updates``; nothing
    is returned.

    When ``self.updates`` already holds pairs from an earlier call, the new
    values are chained onto those pending expressions instead of onto
    ``self.mean`` / ``self.var``.

    Args:
        input: Tensor to take statistics from. With ``self.mode == 0`` it
            is reduced over axis 0; otherwise over axes ``(0, 2, 3)``
            (presumably a (batch, channel, height, width) layout -- TODO
            confirm against the caller).
    """
    reduce_first_axis_only = self.mode == 0
    reduce_axes = 0 if reduce_first_axis_only else (0, 2, 3)

    batch_mean = T.mean(input, axis=reduce_axes)
    batch_var = T.var(input, axis=reduce_axes)

    # Number of elements that went into each statistic, as a float so it
    # can take part in the scaling below.
    dims = input.shape
    if reduce_first_axis_only:
        element_count = dims[0]
    else:
        element_count = dims[0] * dims[2] * dims[3]
    n = T.cast(element_count, theano.config.floatX)

    # Start from the stored statistics on the first call, or from the
    # already-pending update expressions on later calls.
    if self.updates is None:
        prev_mean, prev_var = self.mean, self.var
    else:
        prev_mean, prev_var = self.updates[0][1], self.updates[1][1]

    keep = self.momentum
    blend = 1.0 - keep

    # NOTE(review): the batch variance is scaled by (n + 1) / n; the usual
    # unbiased correction is n / (n - 1) -- confirm this factor is intended.
    scaled_var = (n + 1.0) / n * batch_var

    running_mean = keep * prev_mean + blend * batch_mean
    running_var = keep * prev_var + blend * scaled_var

    self.updates = [(self.mean, running_mean), (self.var, running_var)]
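# Comment list (web-page navigation text; not part of the code)
# Article table of contents (web-page navigation text; not part of the code)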