def update_opt(self, f, target, inputs, reg_coeff):
    """Build a lazily compiled Hessian-vector product function for ``f``.

    Uses Pearlmutter's trick: the gradient, with respect to the parameters,
    of the inner product between ``grad(f)`` and a vector ``x`` equals the
    Hessian of ``f`` applied to ``x``. No explicit Hessian is formed.

    Args:
        f: Symbolic expression whose Hessian is applied. It must be a scalar,
            as ``theano.grad`` requires a scalar cost.
        target: Object exposing ``get_params(trainable=True)``. Its trainable
            parameters are the variables the Hessian is taken with respect to.
        inputs: Sequence (tuple or list) of symbolic input variables. They
            become the leading inputs of the compiled function.
        reg_coeff: Regularization coefficient. It is only stored on ``self``
            here and is not used by this method.

    Side effects:
        Sets ``self.target``, ``self.reg_coeff`` and ``self.opt_fun``.
        ``self.opt_fun`` is an ``ext.lazydict`` whose ``"f_Hx_plain"`` entry
        compiles a function with these properties:

        - Its inputs are ``(*inputs, *xs)``, with one ``x`` placeholder per
          trainable parameter.
        - It returns the per-parameter Hessian-vector product blocks,
          flattened and concatenated into one vector.
    """
    self.target = target
    self.reg_coeff = reg_coeff
    params = target.get_params(trainable=True)

    # First-order gradient of f w.r.t. each parameter. 'warn' (rather than
    # raise) if some parameter does not influence f.
    constraint_grads = theano.grad(
        f, wrt=params, disconnected_inputs='warn')
    # One symbolic placeholder per parameter for the vector multiplied by
    # the Hessian. It is presumably typed like the parameter via
    # ext.new_tensor_like (TODO confirm).
    xs = tuple(ext.new_tensor_like("%s x" % p.name, p) for p in params)

    # Coerce to a tuple so callers may pass a list: ``list + tuple`` raises
    # TypeError. Built eagerly so that later mutation of a caller's list
    # cannot change what the deferred compilation below sees.
    all_inputs = tuple(inputs) + xs

    def Hx_plain():
        # grad_params( sum_i <g_i, x_i> ) == H x, because the xs are
        # independent placeholders that do not depend on params.
        Hx_plain_splits = TT.grad(
            TT.sum([TT.sum(g * x)
                    for g, x in zip(constraint_grads, xs)]),
            wrt=params,
            disconnected_inputs='warn'
        )
        # Flatten each per-parameter block and join them into one vector.
        return TT.concatenate([TT.flatten(s) for s in Hx_plain_splits])

    # Compilation is wrapped in a lambda so it is deferred until
    # ext.lazydict evaluates the entry (presumably on first access).
    self.opt_fun = ext.lazydict(
        f_Hx_plain=lambda: ext.compile_function(
            inputs=all_inputs,
            outputs=Hx_plain(),
            log_name="f_Hx_plain",
        ),
    )
conjugate_gradient_optimizer.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录