network.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:shuttleNet 作者: shiyemin 项目源码 文件源码
def batch_normalization(self, input, name,
                            scale_offset=True,
                            relu=False,
                            decay=0.999,
                            moving_vars='moving_vars',
                            variance_epsilon=1e-5):
        """Batch-normalize ``input`` over every axis except the last one.

        Under the variable scope ``name`` this creates:
          * optional learnable ``scale`` (initialized to 1) and ``offset``
            (initialized to 0) variables, when ``scale_offset`` is True;
          * non-trainable ``mean`` (initialized to 0) and ``variance``
            (initialized to 1) variables. These are registered in the
            ``moving_vars`` collection and in
            ``tf.GraphKeys.MOVING_AVERAGE_VARIABLES``.

        When ``self.trainable`` is true, the input is normalized with the
        moments of the current batch. Ops that fold those moments into the
        stored mean/variance (exponential moving average with ``decay``) are
        added to ``UPDATE_OPS_COLLECTION``. They are NOT control dependencies
        of the returned tensor, so the caller must run them explicitly.
        When ``self.trainable`` is false, the stored mean/variance are used
        as-is (inference mode).

        Args:
            input: Tensor to normalize. Statistics are reduced over all axes
                but the last, so the last axis is treated as channels
                (presumably NHWC for conv features -- confirm with callers).
            name: Variable scope and op name for this layer.
            scale_offset: If True, create and apply ``scale``/``offset``;
                otherwise normalize without an affine transform.
            relu: If True, apply ReLU to the normalized output.
            decay: Decay for the moving mean/variance updates (training mode
                only).
            moving_vars: Name of the extra collection that receives the
                moving mean/variance variables.
            variance_epsilon: Constant added to the variance for numerical
                stability. The default 1e-5 is Caffe's default batch-norm eps;
                pass the source model's actual eps if it differs.

        Returns:
            The normalized (and optionally ReLU-activated) tensor.
        """
        with tf.variable_scope(name):
            # Reduce over every axis except the trailing (channel) axis.
            axis = list(range(len(input.get_shape()) - 1))
            shape = [input.get_shape()[-1]]
            if scale_offset:
                scale = self.make_var('scale', shape=shape,
                                      initializer=tf.ones_initializer(),
                                      trainable=self.trainable)
                offset = self.make_var('offset', shape=shape,
                                       initializer=tf.zeros_initializer(),
                                       trainable=self.trainable)
            else:
                # tf.nn.batch_normalization accepts None for both, meaning
                # "no affine transform".
                scale, offset = (None, None)
            # Create moving_mean and moving_variance and add them to the
            # GraphKeys.MOVING_AVERAGE_VARIABLES collection (plus moving_vars).
            moving_collections = [moving_vars, tf.GraphKeys.MOVING_AVERAGE_VARIABLES]
            moving_mean = self.make_var('mean',
                                        shape,
                                        initializer=tf.zeros_initializer(),
                                        trainable=False,
                                        collections=moving_collections)
            moving_variance = self.make_var('variance',
                                            shape,
                                            initializer=tf.ones_initializer(),
                                            trainable=False,
                                            collections=moving_collections)
            if self.trainable:
                # Training: normalize with the current batch's moments.
                mean, variance = tf.nn.moments(input, axis)

                # The update ops are only collected, not wired into the
                # output's dependencies; the training loop must run them.
                update_moving_mean = moving_averages.assign_moving_average(
                    moving_mean, mean, decay)
                tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
                update_moving_variance = moving_averages.assign_moving_average(
                    moving_variance, variance, decay)
                tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
            else:
                # Inference: use the stored moving statistics directly.
                mean = moving_mean
                variance = moving_variance
            output = tf.nn.batch_normalization(
                input,
                mean=mean,
                variance=variance,
                offset=offset,
                scale=scale,
                variance_epsilon=variance_epsilon,
                name=name)
            if relu:
                output = tf.nn.relu(output)
            return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号