def setup(dim, criterion='mmd', biased=True, streaming_est=False, opt_log=True,
          linear_kernel=False, opt_sigma=False, init_log_sigma=0,
          net_version='basic', hotelling_reg=0,
          strat='nesterov_momentum', learning_rate=0.01, **opt_args):
    """Build the two-sample network and compile its Theano functions.

    Creates symbolic matrix inputs for the two samples, builds the network
    via ``make_network``, constructs parameter updates with the requested
    ``lasagne.updates`` rule, and compiles training, validation and
    representation functions.

    Parameters
    ----------
    dim :
        Passed positionally to ``make_network`` after the two inputs
        (presumably the input dimensionality -- confirm against
        ``make_network``).
    criterion, biased, streaming_est, opt_log, linear_kernel, net_version,
    hotelling_reg :
        Forwarded unchanged to ``make_network``.
    opt_sigma : bool
        If true, the ``log_sigma`` returned by ``make_network`` is appended
        to the parameter list that the update rule optimizes.
    init_log_sigma :
        Forwarded to ``make_network`` as its ``log_sigma`` argument.
    strat : str
        Name of an update function in ``lasagne.updates``
        (default ``'nesterov_momentum'``).
    learning_rate : float
        Passed to the update function as ``learning_rate``.
    **opt_args :
        Extra keyword arguments passed through to the update function.

    Returns
    -------
    tuple
        ``(params, train_fn, val_fn, get_rep, log_sigma)`` where

        * ``params`` -- trainable parameters of ``net_p`` and ``net_q``
          (plus ``log_sigma`` when ``opt_sigma`` is true);
        * ``train_fn(p, q)`` -- returns ``[mmd2_pq, obj]`` and applies
          the computed updates;
        * ``val_fn(p, q)`` -- returns ``[mmd2_pq, obj]`` without updating;
        * ``get_rep(p)`` -- returns the ``rep_p`` output of
          ``make_network`` for input ``p``;
        * ``log_sigma`` -- the object returned by ``make_network``.

    Raises
    ------
    AttributeError
        If ``strat`` is not an attribute of ``lasagne.updates``.
    """
    input_p = T.matrix('input_p')
    input_q = T.matrix('input_q')

    mmd2_pq, obj, rep_p, net_p, net_q, log_sigma = make_network(
        input_p, input_q, dim,
        criterion=criterion, biased=biased, streaming_est=streaming_est,
        opt_log=opt_log, linear_kernel=linear_kernel, log_sigma=init_log_sigma,
        hotelling_reg=hotelling_reg, net_version=net_version)

    params = lasagne.layers.get_all_params([net_p, net_q], trainable=True)
    if opt_sigma:
        params.append(log_sigma)

    # Look up the update rule by name; ``obj`` is handed to it as the loss.
    update_rule = getattr(lasagne.updates, strat)
    updates = update_rule(obj, params, learning_rate=learning_rate, **opt_args)

    # Flush explicitly: the message has no trailing newline, so a
    # line-buffered stderr would otherwise hold it until "done" is printed,
    # i.e. only after the (possibly slow) compilation has already finished.
    print("Compiling...", file=sys.stderr, end='')
    sys.stderr.flush()
    train_fn = theano.function(
        [input_p, input_q], [mmd2_pq, obj], updates=updates)
    val_fn = theano.function([input_p, input_q], [mmd2_pq, obj])
    get_rep = theano.function([input_p], rep_p)
    print("done", file=sys.stderr)

    return params, train_fn, val_fn, get_rep, log_sigma
评论列表
文章目录