test_enum.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def finite_difference(eval_loss, delta=0.1):
    """
    Computes finite-difference approximation of all parameters.
    """
    params = pyro.get_param_store().get_all_param_names()
    assert params, "no params found"
    grads = {name: Variable(torch.zeros(pyro.param(name).size())) for name in params}
    for name in sorted(params):
        value = pyro.param(name).data
        for index in itertools.product(*map(range, value.size())):
            center = value[index]
            value[index] = center + delta
            pos = eval_loss()
            value[index] = center - delta
            neg = eval_loss()
            value[index] = center
            grads[name][index] = (pos - neg) / (2 * delta)
    return grads
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号