def sample(self, x, K):
    """Run K steps of the auto-encoder sampling chain started from ``x``.

    One step maps the current hidden code ``h`` to the next one:

    1. ``s = h * (1 - h)``;
    2. ``jj[b] = diag(s[b]) . (W^T W) . diag(s[b])`` for every row ``b``,
       where ``W = self.params[0]``;
    3. ``delta[b] = alpha[b] . jj[b]`` with ``alpha ~ N(0, self.sigma)``
       drawn from ``self.srng`` (dtype ``theano.config.floatX``);
    4. ``z = decode(h + delta)`` and the next code is ``encode(z)``.

    The step is compiled into a single Theano function and iterated on
    concrete values. Each returned sample is therefore produced from the
    previously returned state, so the list is one chain.

    Parameters
    ----------
    x : array-like or Theano variable
        Starting point. A 1-D ``x`` is reshaped to ``(1, n)``; otherwise it
        is used as a 2-D batch. It must not depend on free symbolic inputs,
        because its code is evaluated directly.
    K : int
        Number of chain steps, i.e. number of samples returned.

    Returns
    -------
    list
        The decoded value ``z`` after each of the K steps, in order, as the
        numeric arrays returned by the compiled step function. Empty when
        ``K <= 0``.
    """
    if x.ndim == 1:
        x = x.reshape(1, x.shape[0])
    h0 = self.encode(x)
    if K <= 0:
        return []

    W = self.params[0]  # NOTE(review): assumed to be the encoder weight matrix -- confirm
    ww = T.dot(W.T, W)

    # Symbolic one-step transition, built on a fresh input with h0's type.
    h = h0.type('h')
    # Presumably the sigmoid derivative, i.e. assumes encode ends in a sigmoid -- confirm.
    s = h * (1. - h)
    # jj[b, i, j] = ww[i, j] * s[b, i] * s[b, j]
    jj = ww * s.dimshuffle(0, 'x', 1) * s.dimshuffle(0, 1, 'x')
    alpha = self.srng.normal(size=h.shape,
                             avg=0.,
                             std=self.sigma,
                             dtype=theano.config.floatX)
    # delta[b, j] = sum_i alpha[b, i] * jj[b, i, j]
    delta = (alpha.dimshuffle(0, 1, 'x') * jj).sum(1)
    z_next = self.decode(h + delta)
    h_next = self.encode(z_next)

    # Compile once rather than calling .eval() on an ever-growing unrolled
    # graph. Re-evaluating the unrolled graph redrew all earlier noise on
    # every call, so consecutive samples were not states of the same chain.
    # Fresh noise per call relies on the random stream's default RNG-state
    # updates -- NOTE(review): confirm for the srng class in use.
    step = theano.function([h], [z_next, h_next])

    hn = h0.eval()
    samples = []
    for _ in range(K):
        zn, hn = step(hn)
        samples.append(zn)
    return samples
评论列表
文章目录