siamese_net.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:tensorflow-siamese-fc 作者: www0wwwjs1 项目源码 文件源码
def batchNorm(self, x, isTraining):
        """Batch-normalize ``x`` over all axes except the last (channel) axis.

        Under the ``bn`` variable scope, four per-channel parameters are
        obtained through ``self.getVariable``:

        - ``beta``: offset, initialised to 0, learning-rate multiplier 1.0.
        - ``gamma``: scale, initialised to 1, learning-rate multiplier 2.0.
        - ``moving_mean``: non-trainable, initialised to 0.
        - ``moving_variance``: non-trainable, initialised to 1.

        The multipliers are recorded in ``self.learningRates`` under each
        variable's name.

        Ops that update the moving statistics (decay ``MOVING_AVERAGE_DECAY``)
        are only added to the ``UPDATE_OPS_COLLECTION`` graph collection; they
        are not run by this method, so the training step must fetch that
        collection for the moving statistics to advance.

        Args:
            x: Input tensor. Its last dimension is treated as the channel axis
                and sizes the per-channel parameters (presumably NHWC feature
                maps -- confirm against callers).
            isTraining: Chooses the batch statistics (true) or the moving
                statistics (false). Either a boolean scalar tensor, resolved
                at run time with ``cond``, or a plain Python ``bool``,
                resolved while the graph is being built.

        Returns:
            ``tf.nn.batch_normalization`` applied to ``x`` with
            ``variance_epsilon=0.001``.
        """
        shape = x.get_shape()
        # One parameter per channel: the shape of the last axis only.
        paramsShape = shape[-1:]

        # Moments are reduced over every axis except the last (channel) one.
        axis = list(range(len(shape) - 1))

        with tf.variable_scope('bn'):
            beta = self.getVariable('beta', paramsShape, initializer=tf.constant_initializer(value=0, dtype=tf.float32))
            # NOTE(review): learningRates appears to map variable names to
            # per-variable learning-rate multipliers used elsewhere -- confirm.
            self.learningRates[beta.name] = 1.0
            gamma = self.getVariable('gamma', paramsShape, initializer=tf.constant_initializer(value=1, dtype=tf.float32))
            self.learningRates[gamma.name] = 2.0
            movingMean = self.getVariable('moving_mean', paramsShape, initializer=tf.constant_initializer(value=0, dtype=tf.float32), trainable=False)
            movingVariance = self.getVariable('moving_variance', paramsShape, initializer=tf.constant_initializer(value=1, dtype=tf.float32), trainable=False)

        mean, variance = tf.nn.moments(x, axis)
        # The update ops are registered in a collection rather than attached as
        # control dependencies, so whoever runs the train step must fetch them.
        updateMovingMean = moving_averages.assign_moving_average(movingMean, mean, MOVING_AVERAGE_DECAY)
        updateMovingVariance = moving_averages.assign_moving_average(movingVariance, variance, MOVING_AVERAGE_DECAY)
        tf.add_to_collection(UPDATE_OPS_COLLECTION, updateMovingMean)
        tf.add_to_collection(UPDATE_OPS_COLLECTION, updateMovingVariance)

        if isinstance(isTraining, bool):
            # cond rejects a Python bool as its predicate in TF 1.x graph mode,
            # so a static flag picks the statistics at graph-construction time.
            if not isTraining:
                mean, variance = movingMean, movingVariance
        else:
            mean, variance = control_flow_ops.cond(isTraining, lambda: (mean, variance), lambda: (movingMean, movingVariance))

        # Argument order is (x, mean, variance, offset, scale): beta is the
        # offset and gamma the scale.
        x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, variance_epsilon=0.001)

        return x

    # def batchNormalization(self, inputs, isTraining, name):
    #     with tf.variable_scope('bn'):
    #         output = tf.contrib.layers.batch_norm(inputs, center=True, scale=True, is_training=isTraining, decay=0.997, epsilon=0.0001)
    #     self.learningRates[name+'/bn/BatchNorm/gamma:0'] = 2.0
    #     self.learningRates[name+'/bn/BatchNorm/beta:0'] = 1.0
    #
    #     return output
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号