BaseModel.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目: kaggle-review 作者: daxiongshu 项目源码 文件源码
def _batch_normalization(self, x, layer_name, eps=0.001):
        """Apply batch normalization to ``x``, optionally seeded from stored weights.

        ``self._get_batch_normalization_weights(layer_name)`` supplies
        ``(beta, gamma, mean, variance)`` as numpy arrays. When ``beta`` is
        ``None``, no stored weights are used.

        The layer is first built with ``tf.layers.batch_normalization`` inside a
        variable scope named after the last ``/``-separated component of
        ``layer_name``. If that call raises, the method falls back to the
        stateless ``tf.nn.batch_normalization``:

        - with identity statistics (mean 0, variance 1, offset 0, scale 1) when
          no weights were loaded;
        - with the loaded statistics otherwise.

        Afterwards, the tensors ``<layer_name>/batch_normalization/moving_mean:0``
        and ``<layer_name>/batch_normalization/moving_variance:0`` are added to
        the ``tf.GraphKeys.SAVE_TENSORS`` collection when they can be found.
        Lookup failures are ignored.

        Args:
            x: Input tensor.
            layer_name: Layer name. Its last ``/`` component names the variable
                scope; the full name is used to look up the moving statistics.
            eps: Small constant added to the variance for numerical stability.
                Used by both the ``tf.layers`` path and the ``tf.nn`` fallback.

        Returns:
            The normalized tensor.
        """
        with tf.variable_scope(layer_name.split('/')[-1]):
            beta, gamma, mean, variance = self._get_batch_normalization_weights(layer_name)
            # beta, gamma, mean, variance are numpy arrays (beta is None if absent).

            if beta is None:
                try:
                    net = tf.layers.batch_normalization(x, epsilon=eps)
                except Exception:
                    # Identity statistics: the result is x / sqrt(1 + eps).
                    # Previously hard-coded 0.01 here, ignoring `eps`.
                    net = tf.nn.batch_normalization(x, 0, 1, 0, 1, eps)
            else:
                try:
                    net = tf.layers.batch_normalization(
                        x,
                        epsilon=eps,
                        beta_initializer=tf.constant_initializer(value=beta, dtype=tf.float32),
                        gamma_initializer=tf.constant_initializer(value=gamma, dtype=tf.float32),
                        moving_mean_initializer=tf.constant_initializer(value=mean, dtype=tf.float32),
                        moving_variance_initializer=tf.constant_initializer(value=variance, dtype=tf.float32),
                    )
                except Exception:
                    # Argument order: (x, mean, variance, offset, scale, variance_epsilon).
                    net = tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps)

        # NOTE(review): the scope above uses only the last component of
        # layer_name, while the lookup below uses the full layer_name. The
        # tensors are found only if the enclosing scopes reproduce that prefix.
        # Confirm against callers.
        graph = tf.get_default_graph()
        for stat_name in ('moving_mean', 'moving_variance'):
            tensor_name = '%s/batch_normalization/%s:0' % (layer_name, stat_name)
            try:
                tf.add_to_collection(tf.GraphKeys.SAVE_TENSORS,
                                     graph.get_tensor_by_name(tensor_name))
            except (KeyError, ValueError, AttributeError):
                # Best effort. KeyError/ValueError: the tensor does not exist,
                # e.g. presumably after the tf.nn fallback. AttributeError:
                # SAVE_TENSORS is not a built-in GraphKeys member; presumably
                # the project defines it elsewhere. TODO confirm.
                pass
        return net
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号