def test_infer():
    """Smoke-test the inference path of the model built by ``test_build_gdir``.

    Builds a GBN from a ``Euclidean`` data iterator's input dimension, wraps
    it with ``test_build_gdir`` (7 inference steps, gradients passed), then:

    1. compiles the graph returned by ``gdir.inference`` (compile-only check;
       the compiled function is not evaluated), and
    2. compiles the graph returned by calling ``gdir`` directly and evaluates
       it on one batch from the iterator, printing the outputs.
    """
    data_iter = Euclidean(batch_size=27, dim_in=17)
    gbn = test_vae.test_build_GBN(dim_in=data_iter.dims[data_iter.name])

    inference_args = dict(
        n_inference_steps=7,
        pass_gradients=True,
    )
    gdir = test_build_gdir(gbn, **inference_args)

    X = T.matrix('x', dtype=floatX)

    # Compile-only check of the inference graph. This function is kept under
    # its own name, instead of being silently overwritten as before, to make
    # clear that it is built but intentionally never called.
    rval, _constants, inference_updates = gdir.inference(X, X)
    # list(...) so theano receives a list of outputs on Python 3 as well,
    # where dict.values() is a view rather than a list.
    _inference_fn = theano.function(
        [X], list(rval.values()), updates=inference_updates)

    x = data_iter.next()[data_iter.name]

    # Full forward call: compile and evaluate on one batch.
    results, _samples, _full_results, updates = gdir(X, X)
    f = theano.function([X], list(results.values()), updates=updates)
    # Function-call form prints the same output on Python 2 and Python 3.
    print(f(x))
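# NOTE(review): the following two labels are web-page navigation residue
# captured along with this snippet ("评论列表" = "Comment list",
# "文章目录" = "Table of contents"). They are not Python and are kept only
# as a comment.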