def create_train_func(layers, learning_rate=1e-3, momentum=0.9, grid_size=20):
    """Compile a Theano function that performs one training step.

    The network whose output layer is ``layers['trans']`` predicts a
    transform from the image pair fed to ``layers['inputa']`` and
    ``layers['inputb']``. The loss is the mean squared difference between
    the grids obtained by applying the ground-truth and the predicted
    transforms to a shared reference grid. Trainable parameters are
    updated with Nesterov momentum.

    Args:
        layers: dict of Lasagne layers; must contain the keys ``'trans'``,
            ``'inputa'`` and ``'inputb'``.
        learning_rate: step size passed to ``nesterov_momentum``
            (default 1e-3).
        momentum: momentum coefficient passed to ``nesterov_momentum``
            (default 0.9).
        grid_size: number of grid points passed to ``_meshgrid`` for both
            of its dimensions (default 20).

    Returns:
        A compiled ``theano.function`` taking ``(Xa_batch, Xb_batch,
        Tg_batch)``: two 4D tensors and a float32 matrix of ground-truth
        transforms. Each call applies one parameter update and returns
        ``[Tp, train_loss]``, where ``Tp`` is the raw network output.
    """
    # Placeholders the graph is built on; the compiled function's inputs
    # (the *_batch variables) are substituted for them through ``givens``.
    Xa, Xb = T.tensor4('Xa'), T.tensor4('Xb')
    Xa_batch, Xb_batch = T.tensor4('Xa_batch'), T.tensor4('Xb_batch')

    # Training-mode forward pass (deterministic=False keeps any stochastic
    # layers, e.g. dropout, active).
    Tp = get_output(
        layers['trans'],
        inputs={
            layers['inputa']: Xa,
            layers['inputb']: Xb,
        },
        deterministic=False,
    )

    # transforms: ground-truth, predicted
    Tg = T.fmatrix('Tg')
    Tg_batch = T.fmatrix('Tg_batch')
    # Assumes every row is one flattened 2x3 transform matrix (6 values).
    theta_gt = Tg.reshape((-1, 2, 3))
    theta_pr = Tp.reshape((-1, 2, 3))

    # grids: ground-truth, predicted. The reference grid is identical for
    # both, so build it once.
    # NOTE(review): _meshgrid presumably returns homogeneous coordinates of
    # shape (3, H*W) so the (N, 2, 3) transforms can be dotted with it --
    # confirm against its definition.
    grid = _meshgrid(grid_size, grid_size)
    Gg = T.dot(theta_gt, grid)
    Gp = T.dot(theta_pr, grid)
    train_loss = T.mean(T.sqr(Gg - Gp))

    params = get_all_params(layers['trans'], trainable=True)
    updates = nesterov_momentum(train_loss, params, learning_rate, momentum)

    train_func = theano.function(
        inputs=[theano.In(Xa_batch), theano.In(Xb_batch), theano.In(Tg_batch)],
        outputs=[Tp, train_loss],
        updates=updates,
        givens={
            Xa: Xa_batch, Xb: Xb_batch,  # Ia, Ib
            Tg: Tg_batch,  # transform Ia --> Ib
        },
    )
    return train_func
theano_funcs.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录