def get_output_for(self, input, deterministic=False, **kwargs):
    """Build the normalized, scaled, shifted and activated output.

    With ``deterministic=True`` the stored running statistics
    (``self.avg_batch_mean`` / ``self.avg_batch_var``) normalize the
    input. Otherwise the mean and variance are computed from ``input``
    over ``self.axes_to_sum``, and expressions for the refreshed running
    statistics are recorded on ``self.bn_updates``.

    Args:
        input: Expression to normalize.
        deterministic: When true, use the running statistics instead of
            statistics computed from ``input``.
        **kwargs: Accepted but not used by this method.

    Returns:
        ``self.nonlinearity`` applied to the normalized input after the
        optional ``self.g`` scale and ``self.b`` shift.
    """
    pattern = self.dimshuffle_args

    def expand(vector):
        # Apply the layer's dimshuffle pattern; presumably this makes a
        # per-feature vector broadcast against ``input`` (pattern is
        # defined outside this method).
        return vector.dimshuffle(*pattern)

    if not deterministic:
        # Statistics of the current input over the configured axes.
        mean = T.mean(input, axis=self.axes_to_sum).flatten()
        centered = input - expand(mean)
        variance = T.mean(T.square(centered), axis=self.axes_to_sum).flatten()
        normalized = centered / expand(T.sqrt(1e-6 + variance))

        # BN updates: running averages keep 0.9 of the old value and take
        # 0.1 of the new one; the variance share is additionally scaled
        # by shape[0] / (shape[0] - 1).
        # NOTE(review): the scale only looks at input.shape[0] and is
        # undefined when that is 1 -- confirm this is intended when
        # axes_to_sum covers more than the first axis.
        variance_weight = T.cast(
            (0.1 * input.shape[0]) / (input.shape[0] - 1),
            th.config.floatX,
        )
        self.bn_updates = [
            (self.avg_batch_mean, 0.9 * self.avg_batch_mean + 0.1 * mean),
            (self.avg_batch_var, 0.9 * self.avg_batch_var + variance_weight * variance),
        ]
    else:
        # 1e-6 keeps the square root away from zero variance.
        running_std = T.sqrt(1e-6 + self.avg_batch_var)
        normalized = (input - expand(self.avg_batch_mean)) / expand(running_std)

    # Scale and shift are optional: each is applied only if the layer
    # actually carries the corresponding attribute.
    activation = normalized
    if hasattr(self, 'g'):
        activation = normalized * expand(self.g)
    if hasattr(self, 'b'):
        activation += expand(self.b)

    return self.nonlinearity(activation)
评论列表
文章目录