def finite_difference(eval_loss, delta=0.1):
    """
    Computes a central finite-difference approximation of the gradient of
    ``eval_loss`` with respect to every parameter in the Pyro param store.

    Each scalar entry of each parameter is perturbed in place by ``+delta``
    and ``-delta``; the loss is re-evaluated at both points and the slope
    ``(pos - neg) / (2 * delta)`` is recorded.  The entry is restored to its
    original value before moving on.

    :param eval_loss: zero-argument callable returning the loss evaluated
        with the current parameter values.
    :param delta: half-width of the finite-difference step.
    :returns: dict mapping each parameter name to a ``Variable`` of the same
        size as the parameter, holding the estimated partial derivatives.
    """
    params = pyro.get_param_store().get_all_param_names()
    assert params, "no params found"
    grads = {}
    for name in sorted(params):
        value = pyro.param(name).data
        # Allocate the gradient buffer with the parameter's own tensor type
        # (dtype/device) rather than the global default float CPU type.
        grad = Variable(value.new(value.size()).zero_())
        for index in itertools.product(*map(range, value.size())):
            # Snapshot as a plain Python number: on newer PyTorch ``value[index]``
            # is a 0-dim view into ``value`` and would track the writes below,
            # corrupting both the negative probe and the final restore.
            center = float(value[index])
            value[index] = center + delta
            pos = eval_loss()
            value[index] = center - delta
            neg = eval_loss()
            value[index] = center  # restore the original parameter value
            grad[index] = (pos - neg) / (2 * delta)
        grads[name] = grad
    return grads
评论列表
文章目录