def get_corrupted_input(rng, input, corruption_level, ntype='zeromask',
                        salt_prob=0.5):
    """Return a symbolically corrupted version of ``input``.

    Builds a Theano expression that applies one of three noise models to
    ``input``, using an ``MRG_RandomStreams`` seeded from ``rng``.

    :param rng: generator used only to draw the MRG seed via
        ``rng.randint(2 ** 30)``; presumably a ``numpy.random.RandomState``
        (TODO confirm against callers).
    :param input: symbolic Theano tensor to corrupt.
    :param corruption_level: noise strength. For ``'zeromask'`` and
        ``'salt_pepper'`` it is the per-element probability of corruption;
        for ``'gaussian'`` it is the standard deviation of the additive
        noise. ``0.0`` returns ``input`` unchanged.
    :param ntype: noise model, one of ``'zeromask'``, ``'gaussian'`` or
        ``'salt_pepper'``.
    :param salt_prob: ``'salt_pepper'`` only -- probability that a corrupted
        element is set to 1 (salt) rather than 0 (pepper). Defaults to an
        even split.
    :returns: the corrupted symbolic tensor, or ``input`` itself when
        ``corruption_level == 0.0``.
    :raises ValueError: if ``ntype`` is not a supported noise model.
    """
    # The seed is drawn before the zero-level early exit, so ``rng`` is
    # advanced on every call, including when no corruption is applied.
    MRG = RNG_MRG.MRG_RandomStreams(rng.randint(2 ** 30))
    if corruption_level == 0.0:
        return input

    floatX = theano.config.floatX
    if ntype == 'zeromask':
        # Keep each element with probability 1 - corruption_level, else zero it.
        keep = MRG.binomial(size=input.shape, n=1, p=1 - corruption_level,
                            dtype=floatX)
        return keep * input
    if ntype == 'gaussian':
        # Additive zero-mean Gaussian noise with std = corruption_level.
        return input + MRG.normal(size=input.shape, avg=0.0,
                                  std=corruption_level, dtype=floatX)
    if ntype == 'salt_pepper':
        print('DAE uses salt and pepper noise')
        # keep == 1: element passes through; keep == 0: element is corrupted.
        keep = MRG.binomial(size=input.shape, n=1, p=1 - corruption_level,
                            dtype=floatX)
        # Value written into corrupted elements: 1 (salt) with probability
        # salt_prob, otherwise 0 (pepper). This must not depend on
        # corruption_level, or the salt/pepper balance drifts with the rate.
        salt = MRG.binomial(size=input.shape, n=1, p=salt_prob, dtype=floatX)
        return input * keep + T.eq(keep, 0) * salt
    raise ValueError(
        "unknown ntype %r; expected 'zeromask', 'gaussian' or 'salt_pepper'"
        % (ntype,))
评论列表
文章目录