dropout.py source code

python

Project: jack  Author: uclmr
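import tensorflow as tf  # written against the TensorFlow 1.x API


def fixed_dropout(xs, keep_prob, noise_shape, seed=None):
    """
    Apply dropout with the same mask to all inputs.

    Args:
        xs: list of tensors; the dtype of xs[0] is used for the mask
        keep_prob: probability that an element is kept
        noise_shape: shape of the shared random mask; it must broadcast
            against every tensor in xs
        seed: optional seed for the random mask

    Returns:
        list of dropped inputs, one per tensor in xs
    """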
    with tf.name_scope("dropout", values=xs):
        # Uniform noise in [keep_prob, 1.0 + keep_prob)
        random_tensor = keep_prob
        random_tensor += tf.random_uniform(noise_shape, seed=seed, dtype=xs[0].dtype)
        # floor gives 0. for [keep_prob, 1.0) and 1. for [1.0, 1.0 + keep_prob),
        # so each mask entry is 1 with probability keep_prob
        binary_tensor = tf.floor(random_tensor)
        outputs = []
        for x in xs:
            ret = tf.div(x, keep_prob) * binary_tensor
            ret.set_shape(x.get_shape())
            outputs.append(ret)
        return outputs
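
The mask trick used above can be tried without TensorFlow. The following is a minimal pure-Python sketch, not part of the jack project: the name `shared_mask_dropout` and the use of flat lists instead of tensors are assumptions made for illustration. It shows that `floor(keep_prob + u)` with `u` uniform in [0, 1) yields a 0/1 mask, and that the same mask is applied to every input.

```python
import math
import random


def shared_mask_dropout(xs, keep_prob, seed=None):
    """Hypothetical sketch: drop the same positions in every list in xs."""
    rng = random.Random(seed)
    length = len(xs[0])
    # floor(keep_prob + u), u uniform in [0, 1), is 1 with probability keep_prob.
    mask = [math.floor(keep_prob + rng.random()) for _ in range(length)]
    # Kept values are scaled by 1 / keep_prob so the expected value is unchanged.
    outputs = [[v / keep_prob * m for v, m in zip(x, mask)] for x in xs]
    return outputs, mask


a = [1.0, 2.0, 3.0, 4.0]
b = [10.0, 20.0, 30.0, 40.0]
(dropped_a, dropped_b), mask = shared_mask_dropout([a, b], keep_prob=0.5, seed=0)
```

Wherever `mask` is 0, both `dropped_a` and `dropped_b` are zero at that position; wherever it is 1, both hold the original value divided by `keep_prob`.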