network.py source code

python

Project: BDD_Driving_Model  Author: gy20073
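import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim  # TensorFlow 1.x only (contrib was removed in 2.x)


# Method of the network class (excerpt); get_saved_value is defined elsewhere in that class.
def batch_normalization(self, input, name, scale_offset=True, relu=False, is_training=False):
    with tf.variable_scope(name):
        # Keyword arguments forwarded to slim.batch_norm.
        norm_params = {'decay': 0.999,
                       'scale': scale_offset,
                       'epsilon': 0.001,
                       'is_training': is_training,
                       'activation_fn': tf.nn.relu if relu else None}
        if hasattr(self, 'data_dict'):
            # Seed the moving statistics (and beta/gamma when scale_offset is set)
            # from the pretrained weights instead of slim's default initializers.
            param_inits = {'moving_mean': self.get_saved_value('mean'),
                           'moving_variance': self.get_saved_value('variance')}
            if scale_offset:
                param_inits['beta'] = self.get_saved_value('offset')
                param_inits['gamma'] = self.get_saved_value('scale')

            # Reshape every saved array to [channels] to match the variable shapes.
            shape = [int(input.get_shape()[-1])]
            for key in param_inits:
                param_inits[key] = np.reshape(param_inits[key], shape)
            norm_params['param_initializers'] = param_inits
        # TODO: there might be a bug if reusing is enabled.
        return slim.batch_norm(input, **norm_params)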
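To make the arithmetic behind the call above concrete, here is a minimal NumPy sketch of what batch normalization over the last (channel) axis computes with these settings. With `is_training=False` it normalizes with the saved moving statistics; with `is_training=True` it uses the batch statistics and updates the moving averages with `decay`. The helper name `batch_norm_nhwc` is hypothetical and not part of the project or of tf-slim, and TensorFlow's exact variance handling (for example, the Bessel correction applied by fused kernels) may differ from this simplified version.

```python
import numpy as np


def batch_norm_nhwc(x, moving_mean, moving_variance, beta=None, gamma=None,
                    epsilon=0.001, decay=0.999, is_training=False, relu=False):
    """Hypothetical NumPy sketch of batch norm over the last axis.

    Returns (output, new_moving_mean, new_moving_variance).
    """
    x = np.asarray(x, dtype=float)
    channels = x.shape[-1]
    # Reshape saved parameters to [channels], mirroring the np.reshape step above.
    moving_mean = np.reshape(moving_mean, [channels]).astype(float)
    moving_variance = np.reshape(moving_variance, [channels]).astype(float)
    # Defaults match slim's: beta starts at 0, and gamma is 1 when scale is off.
    beta = np.zeros(channels) if beta is None else np.reshape(beta, [channels])
    gamma = np.ones(channels) if gamma is None else np.reshape(gamma, [channels])

    if is_training:
        axes = tuple(range(x.ndim - 1))  # reduce over every axis except channels
        mean = x.mean(axis=axes)
        var = x.var(axis=axes)  # biased (population) variance of the batch
        # Exponential moving averages controlled by `decay`.
        new_mean = decay * moving_mean + (1.0 - decay) * mean
        new_var = decay * moving_variance + (1.0 - decay) * var
    else:
        # Inference: normalize with the stored statistics and leave them unchanged.
        mean, var = moving_mean, moving_variance
        new_mean, new_var = moving_mean, moving_variance

    y = gamma * (x - mean) / np.sqrt(var + epsilon) + beta
    if relu:
        y = np.maximum(y, 0.0)  # same effect as activation_fn=tf.nn.relu
    return y, new_mean, new_var


if __name__ == "__main__":
    x = np.array([[1.0, 2.0], [3.0, 6.0]])  # batch of 2, 2 channels
    y, _, _ = batch_norm_nhwc(x, moving_mean=[[2.0, 4.0]],
                              moving_variance=[[1.0, 4.0]], epsilon=0.0)
    print(y)  # [[-1. -1.]
              #  [ 1.  1.]]
```

Passing the saved statistics with a `(1, C)` shape in the usage example shows why the reshape to `[channels]` matters: the per-channel vectors must broadcast against the channel axis of the input.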