layers.py source code

Project: rllab  Author: rll
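
```python
import tensorflow as tf
from tensorflow.python.training import moving_averages


# Method of the batch-normalization layer class in layers.py (class body not
# shown). self.axis, self.decay, self.epsilon, self.beta, self.gamma,
# self.moving_mean, self.moving_variance and self.input_shape are set up
# elsewhere in that class.
def get_output_for(self, input, phase='train', **kwargs):
    if phase == 'train':
        # Calculate the moments based on the individual batch.
        mean, variance = tf.nn.moments(input, self.axis, shift=self.moving_mean)
        # Update the moving_mean and moving_variance moments.
        update_moving_mean = moving_averages.assign_moving_average(
            self.moving_mean, mean, self.decay)
        update_moving_variance = moving_averages.assign_moving_average(
            self.moving_variance, variance, self.decay)
        # Ops created inside this block run only after both moving-average
        # updates, so every training forward pass also refreshes the running stats.
        with tf.control_dependencies([update_moving_mean,
                                      update_moving_variance]):
            output = tf.nn.batch_normalization(
                input, mean, variance, self.beta, self.gamma, self.epsilon)
    else:
        # Inference: normalize with the accumulated moving statistics instead.
        output = tf.nn.batch_normalization(
            input, self.moving_mean, self.moving_variance,
            self.beta, self.gamma, self.epsilon)
    # Re-attach the layer's known static input shape to the output tensor.
    output.set_shape(self.input_shape)
    return output
```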
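
To make the two phases concrete, here is a minimal pure-Python sketch (no TensorFlow) of what the method computes for a 2-D batch normalized over axis 0. Everything in it is hypothetical: the class name `BatchNormSketch`, the helper functions, the default `decay`/`eps` values, and the initial moving statistics (zeros for the mean, ones for the variance) are assumptions, since the layer's constructor is not shown. The moving-average update uses the plain form `moving = decay * moving + (1 - decay) * batch_value`; TensorFlow's `assign_moving_average` can additionally apply zero-debiasing depending on its `zero_debias` argument, which this sketch leaves out.

```python
import math
from typing import List, Tuple

Batch = List[List[float]]  # rows are examples, columns are features


def batch_moments(x: Batch) -> Tuple[List[float], List[float]]:
    """Per-feature mean and biased (divide-by-N) variance over axis 0."""
    n = len(x)
    num_features = len(x[0])
    mean = [sum(row[j] for row in x) / n for j in range(num_features)]
    var = [sum((row[j] - mean[j]) ** 2 for row in x) / n
           for j in range(num_features)]
    return mean, var


def normalize(x: Batch, mean: List[float], var: List[float],
              beta: List[float], gamma: List[float], eps: float) -> Batch:
    """gamma * (x - mean) / sqrt(var + eps) + beta, applied per feature."""
    return [[gamma[j] * (row[j] - mean[j]) / math.sqrt(var[j] + eps) + beta[j]
             for j in range(len(row))]
            for row in x]


class BatchNormSketch:
    """Hypothetical stand-in for the layer; not rllab's actual class."""

    def __init__(self, num_features: int, decay: float = 0.9, eps: float = 1e-3):
        self.decay = decay
        self.eps = eps
        self.beta = [0.0] * num_features             # learned offset (assumed init)
        self.gamma = [1.0] * num_features            # learned scale (assumed init)
        self.moving_mean = [0.0] * num_features      # assumed init
        self.moving_variance = [1.0] * num_features  # assumed init

    def __call__(self, x: Batch, phase: str = "train") -> Batch:
        if phase == "train":
            mean, var = batch_moments(x)
            d = self.decay
            # Plain exponential moving average (no zero-debiasing).
            self.moving_mean = [d * m + (1 - d) * b
                                for m, b in zip(self.moving_mean, mean)]
            self.moving_variance = [d * v + (1 - d) * b
                                    for v, b in zip(self.moving_variance, var)]
            # Training output uses the current batch's own statistics.
            return normalize(x, mean, var, self.beta, self.gamma, self.eps)
        # Any other phase: use the running statistics and leave them unchanged.
        return normalize(x, self.moving_mean, self.moving_variance,
                         self.beta, self.gamma, self.eps)


bn = BatchNormSketch(num_features=2, decay=0.5)
batch = [[1.0, 10.0], [3.0, 30.0]]
train_out = bn(batch, phase="train")   # batch stats: mean [2, 20], var [1, 100]
print(bn.moving_mean, bn.moving_variance)  # [1.0, 10.0] [1.0, 50.5]
test_out = bn(batch, phase="test")     # uses moving stats; no update
```

As in the `else` branch of the original method, any `phase` other than `'train'` takes the inference path: inputs are normalized with the running statistics, and those statistics are not modified.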