def corrupt(tensor, corruption_level=0.05):
    """Apply masking noise: independently zero out elements of ``tensor``.

    Each element is set to 0 with probability ``corruption_level`` and left
    unchanged otherwise, so on average a ``corruption_level`` proportion of
    the input is masked.

    :param tensor: A tensor (or tensor-convertible value) whose values are to
        be corrupted.
    :param corruption_level: A float in [0, 1] giving the probability that
        each value is masked (set to zero).
    :return: A tensor with the same shape and dtype as ``tensor`` in which
        the masked entries are zero.
    :raises ValueError: If ``corruption_level`` is a Python number outside
        [0, 1].
    """
    if isinstance(corruption_level, (int, float)) and not 0.0 <= corruption_level <= 1.0:
        raise ValueError(
            "corruption_level must be in [0, 1], got %r" % (corruption_level,)
        )
    tensor = tf.convert_to_tensor(tensor)
    # Compare in float32 so a tensor-valued corruption_level of another float
    # dtype does not cause a dtype mismatch with the uniform samples.
    threshold = tf.cast(corruption_level, tf.float32)
    # Uniform samples lie in [0, 1), so ``>= threshold`` keeps each element
    # with probability 1 - corruption_level (0 keeps all, 1 masks all).
    keep = tf.random.uniform(tf.shape(tensor), dtype=tf.float32) >= threshold
    # Match the input's dtype so float64 / integer tensors can be multiplied.
    mask = tf.cast(keep, tensor.dtype)
    # ``*`` dispatches to the element-wise multiply op; ``tf.mul`` no longer
    # exists in TensorFlow >= 1.0.
    return tensor * mask
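# 评论列表 ("Comment list") — stray web-page residue, not code
# 文章目录 ("Table of contents") — stray web-page residue, not code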