# Assumes the standard TF 1.x imports used elsewhere in this class:
#   import numpy as np
#   import tensorflow as tf
#   import tensorflow.contrib.slim as slim
def batch_normalization(self, input, name, scale_offset=True, relu=False, is_training=False):
    with tf.variable_scope(name) as scope:
        norm_params = {'decay': 0.999,
                       'scale': scale_offset,
                       'epsilon': 0.001,
                       'is_training': is_training,
                       'activation_fn': tf.nn.relu if relu else None}
        if hasattr(self, 'data_dict'):
            # Seed the batch-norm variables with the values loaded from the
            # pretrained model.
            param_inits = {'moving_mean': self.get_saved_value('mean'),
                           'moving_variance': self.get_saved_value('variance')}
            if scale_offset:
                param_inits['beta'] = self.get_saved_value('offset')
                param_inits['gamma'] = self.get_saved_value('scale')
            # slim.batch_norm expects per-channel parameters, i.e. shape [C].
            shape = [input.get_shape()[-1]]
            for key in param_inits:
                param_inits[key] = np.reshape(param_inits[key], shape)
            norm_params['param_initializers'] = param_inits
        # TODO: there might be a bug if reusing is enabled.
        return slim.batch_norm(input, **norm_params)
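
For context, the method above wraps slim.batch_norm and, when pretrained weights are available in self.data_dict, seeds the layer's variables through param_initializers. Below is a minimal, self-contained sketch of that same mechanism (a sketch only, assuming TF 1.x with tf.contrib.slim; the input shape and statistics are made up for illustration, and tf.constant_initializer is used here to wrap the saved values):

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

# Made-up per-channel statistics standing in for values loaded from a
# pretrained model (4 channels to match the placeholder below).
saved_mean = np.zeros(4, dtype=np.float32)
saved_variance = np.ones(4, dtype=np.float32)

inputs = tf.placeholder(tf.float32, [None, 8, 8, 4])

# Inference-mode batch norm seeded with the saved statistics, mirroring the
# norm_params / param_initializers wiring in the method above.
outputs = slim.batch_norm(
    inputs,
    decay=0.999,
    epsilon=0.001,
    scale=True,
    is_training=False,
    activation_fn=tf.nn.relu,
    param_initializers={
        'moving_mean': tf.constant_initializer(saved_mean),
        'moving_variance': tf.constant_initializer(saved_variance),
    })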