builder.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:KBOPrediction 作者: riceluxs1t 项目源码 文件源码
def dropout_selu(
        self,
        x,
        keep_prob,
        alpha=DROP_ALPHA,
        fixedPointMean=0.0,
        fixedPointVar=1.0,
        noise_shape=None,
        seed=None,
        name=None,
        training=False):
        """Apply alpha-dropout (for SELU networks) to ``x`` with affine rescaling.

        Dropped units are set to ``alpha`` instead of 0. The result is then
        transformed by ``a * ret + b``, with ``a`` and ``b`` computed from
        ``keep_prob``, ``alpha``, ``fixedPointMean`` and ``fixedPointVar``
        (the alpha-dropout correction of Klambauer et al., 2017). The
        correction is intended to keep activations at that mean/variance
        fixed point.

        Dropout is applied only when ``training`` is true (selected through
        ``utils.smart_cond``). Otherwise ``x`` is passed through
        ``array_ops.identity``.

        Args:
            x: input; anything ``ops.convert_to_tensor`` accepts.
            keep_prob: probability of KEEPING each unit. Either a Python
                float in (0, 1] or a scalar tensor. A value statically equal
                to 1 disables dropout.
            alpha: value written into dropped units; must be scalar.
            fixedPointMean: mean used in the rescaling formula.
            fixedPointVar: variance used in the rescaling formula.
            noise_shape: shape of the random keep/drop mask. Defaults to the
                dynamic shape of ``x``.
            seed: seed forwarded to ``random_ops.random_uniform``.
            name: optional name for the enclosing ``ops.name_scope``
                (defaults to "dropout").
            training: bool or boolean tensor that selects dropout vs identity.

        Returns:
            A tensor with the same static shape as ``x``.

        Raises:
            ValueError: if ``keep_prob`` is a Python real outside (0, 1].
        """

        def dropout_selu_impl(x, keep_prob, alpha, noise_shape, seed):
            # keep_prob is used directly as the keep probability. Treating it
            # as a drop rate (1 - keep_prob) would invert the dropout and make
            # keep_prob=1.0 fail validation.
            x = ops.convert_to_tensor(x, name="x")
            if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:
                raise ValueError("keep_prob must be a scalar tensor or a float in the "
                                 "range (0, 1], got %g" % keep_prob)
            keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob")
            keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())

            alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha")
            # alpha is broadcast into every dropped unit, so it must be scalar.
            alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar())

            # keep_prob statically known to be 1: nothing is dropped.
            if tensor_util.constant_value(keep_prob) == 1:
                return x

            noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
            # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
            random_tensor = keep_prob + random_ops.random_uniform(
                noise_shape, seed=seed, dtype=x.dtype)
            binary_tensor = math_ops.floor(random_tensor)
            # Kept units pass through unchanged; dropped units become alpha.
            ret = x * binary_tensor + alpha * (1 - binary_tensor)

            # Affine correction: scale a and shift b, derived from the
            # fixed-point mean/variance and the keep/drop mixture.
            a = tf.sqrt(fixedPointVar / (keep_prob * (
                (1 - keep_prob) * tf.pow(alpha - fixedPointMean, 2) + fixedPointVar)))
            b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha)
            ret = a * ret + b
            # The arithmetic above can lose static shape info; restore it.
            ret.set_shape(x.get_shape())
            return ret

        with ops.name_scope(name, "dropout", [x]) as name:
            return utils.smart_cond(training,
                                    lambda: dropout_selu_impl(x, keep_prob, alpha, noise_shape, seed),
                                    lambda: array_ops.identity(x))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号