learn_kernel.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:opt-mmd 作者: dougalsutherland 项目源码 文件源码
def setup(dim, criterion='mmd', biased=True, streaming_est=False, opt_log=True,
          linear_kernel=False, opt_sigma=False, init_log_sigma=0,
          net_version='basic', hotelling_reg=0,
          strat='nesterov_momentum', learning_rate=0.01, **opt_args):
    """Build the kernel-learning network and compile its Theano functions.

    Creates two symbolic input matrices (``input_p`` and ``input_q``), builds
    the network via ``make_network``, sets up optimizer updates, and compiles
    training, validation, and representation functions.

    Parameters
    ----------
    dim
        Forwarded unchanged to ``make_network``; presumably the input
        dimensionality (TODO confirm against ``make_network``).
    criterion, biased, streaming_est, opt_log, linear_kernel, hotelling_reg,
    net_version
        Forwarded unchanged as keyword arguments to ``make_network``.
    opt_sigma : bool
        If true, the ``log_sigma`` object returned by ``make_network`` is
        appended to the trainable parameters so the optimizer updates it too.
    init_log_sigma
        Passed to ``make_network`` as its ``log_sigma`` argument.
    strat : str
        Name of an update function in ``lasagne.updates`` (for example
        ``'nesterov_momentum'``). An unknown name raises ``AttributeError``.
    learning_rate : float
        Passed to the update function as ``learning_rate``.
    **opt_args
        Extra keyword arguments forwarded to the update function.

    Returns
    -------
    params : list
        Trainable parameters of ``net_p`` and ``net_q``, plus ``log_sigma``
        when ``opt_sigma`` is true.
    train_fn : callable
        ``train_fn(p, q) -> [mmd2_pq, obj]``; also applies one optimizer step.
    val_fn : callable
        ``val_fn(p, q) -> [mmd2_pq, obj]``, without updating any parameters.
    get_rep : callable
        ``get_rep(p) -> rep_p``, the network's representation of ``p``.
    log_sigma
        The ``log_sigma`` object returned by ``make_network``.
    """
    input_p = T.matrix('input_p')
    input_q = T.matrix('input_q')

    mmd2_pq, obj, rep_p, net_p, net_q, log_sigma = make_network(
        input_p, input_q, dim,
        criterion=criterion, biased=biased, streaming_est=streaming_est,
        opt_log=opt_log, linear_kernel=linear_kernel, log_sigma=init_log_sigma,
        hotelling_reg=hotelling_reg, net_version=net_version)

    params = lasagne.layers.get_all_params([net_p, net_q], trainable=True)
    if opt_sigma:
        params.append(log_sigma)

    # Look up the optimizer by name. Lasagne update rules minimize their first
    # argument, so ``obj`` is what training drives down.
    update_rule = getattr(lasagne.updates, strat)
    updates = update_rule(obj, params, learning_rate=learning_rate, **opt_args)

    print("Compiling...", file=sys.stderr, end='')
    # The message has no trailing newline, so a line-buffered stderr would hold
    # it until "done" is printed. Flush so it appears before the (potentially
    # slow) Theano compilation starts. sys.stderr.flush() is used instead of
    # print(..., flush=True) to stay compatible with Python 2's print_function.
    sys.stderr.flush()
    train_fn = theano.function(
        [input_p, input_q], [mmd2_pq, obj], updates=updates)
    val_fn = theano.function([input_p, input_q], [mmd2_pq, obj])
    get_rep = theano.function([input_p], rep_p)
    print("done", file=sys.stderr)

    return params, train_fn, val_fn, get_rep, log_sigma
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号