theano_funcs.py source code

python

Project: convnet-for-geometric-matching  Author: hjweide
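import theano
import theano.tensor as T
from lasagne.layers import get_all_params, get_output
from lasagne.updates import nesterov_momentum

# NOTE: _meshgrid is not defined in this snippet. It is assumed to return
# a (3, height * width) grid of homogeneous coordinates, so that a batch of
# 2x3 affine matrices can be applied to it with T.dot.


def create_train_func(layers):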
    Xa, Xb = T.tensor4('Xa'), T.tensor4('Xb')
    Xa_batch, Xb_batch = T.tensor4('Xa_batch'), T.tensor4('Xb_batch')

    Tp = get_output(
        layers['trans'],
        inputs={
            layers['inputa']: Xa, layers['inputb']: Xb,
        }, deterministic=False,
    )

    # transforms: ground-truth, predicted
    Tg = T.fmatrix('Tg')
    Tg_batch = T.fmatrix('Tg_batch')
    theta_gt = Tg.reshape((-1, 2, 3))
    theta_pr = Tp.reshape((-1, 2, 3))

    # grids: ground-truth, predicted
    Gg = T.dot(theta_gt, _meshgrid(20, 20))
    Gp = T.dot(theta_pr, _meshgrid(20, 20))

    # loss: mean squared difference between the two transformed grids
    train_loss = T.mean(T.sqr(Gg - Gp))

    # SGD with Nesterov momentum: learning rate 1e-3, momentum 0.9
    params = get_all_params(layers['trans'], trainable=True)
    updates = nesterov_momentum(train_loss, params, 1e-3, 0.9)

    corr_func = theano.function(
        inputs=[theano.In(Xa_batch), theano.In(Xb_batch), theano.In(Tg_batch)],
        outputs=[Tp, train_loss],
        updates=updates,
        givens={
            Xa: Xa_batch, Xb: Xb_batch,  # Ia, Ib
            Tg: Tg_batch,                # transform Ia --> Ib
        }
    )

    return corr_func
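
The loss above compares transformed point grids rather than the six affine parameters directly. The following is a minimal NumPy sketch of that computation, not the project's own code. It assumes that `_meshgrid(20, 20)` returns a (3, 400) array of homogeneous coordinates; `meshgrid_homogeneous` is a hypothetical stand-in for it, and the [-1, 1] coordinate range is an assumption.

```python
import numpy as np


def meshgrid_homogeneous(height, width):
    """Hypothetical stand-in for _meshgrid.

    Returns a (3, height * width) float32 array whose rows are the x
    coordinates, the y coordinates (both assumed to span [-1, 1]) and ones.
    """
    xs, ys = np.meshgrid(np.linspace(-1.0, 1.0, width),
                         np.linspace(-1.0, 1.0, height))
    ones = np.ones(xs.size)
    return np.vstack([xs.ravel(), ys.ravel(), ones]).astype(np.float32)


def grid_loss(t_gt, t_pr, height=20, width=20):
    """Mean squared difference between grids warped by two affine batches.

    t_gt, t_pr: arrays of shape (N, 6), one flattened 2x3 matrix per row.
    """
    grid = meshgrid_homogeneous(height, width)   # (3, H*W)
    theta_gt = t_gt.reshape((-1, 2, 3))          # (N, 2, 3)
    theta_pr = t_pr.reshape((-1, 2, 3))
    g_gt = np.dot(theta_gt, grid)                # (N, 2, H*W)
    g_pr = np.dot(theta_pr, grid)
    return np.mean((g_gt - g_pr) ** 2)


# identity transform versus a pure horizontal shift of 0.5
identity = np.array([[1, 0, 0, 0, 1, 0]], dtype=np.float32)
shifted = np.array([[1, 0, 0.5, 0, 1, 0]], dtype=np.float32)

loss_same = grid_loss(identity, identity)
loss_shift = grid_loss(identity, shifted)
```

With a shift of 0.5 in x only, every x coordinate differs by 0.5 and every y coordinate by 0, so the mean over both coordinate rows is 0.25 / 2 = 0.125. Measuring the error on the grid ties the loss to how far points actually move under the two transforms.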