numerical_check.py source code

Language: Python

Project: autodiff | Author: bgavran
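```python
import numpy as np

# differentiate_n_times_num and arrays_allclose are helpers defined elsewhere
# in the autodiff project; their module path is not shown in this listing.


def numerical_check(test, graph, wrt_vars, order=1):
    # Build the order-n backprop gradient graphs and the matching numerical estimates
    backprop_graphs, numeric_grads = differentiate_n_times_num(graph, wrt_vars, order=order)

    for wrt_var, graph_grad, num_grad in zip(wrt_vars, backprop_graphs, numeric_grads):
        name = "num" + str(order) + "df_wrt_" + wrt_var.name
        if graph.name == "extra_exp_op":
            name += " as input to another op!!!"
        # Each variable runs as its own subtest, so one failure does not hide the rest
        with test.subTest(name):
            grad_value = graph_grad()  # evaluate the backprop graph once and reuse it
            print("---------- " + name + " ----------")
            print("Backprop grad:", grad_value)
            print("Numeric grad:", num_grad)
            # The backprop gradient need not have the variable's shape,
            # so broadcast it before comparing element-wise
            broadcasted_grad = np.broadcast_to(grad_value, wrt_var().shape)
            arrays_allclose(broadcasted_grad, num_grad)
```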
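For readers without the rest of the project, here is a minimal sketch of the idea `numerical_check` relies on: estimate each partial derivative with a central difference, then compare it to the analytic gradient within a tolerance, broadcasting the analytic gradient first when its shape is smaller. `numeric_grad`, `f` and `g` are hypothetical names written for illustration. This is not the project's `differentiate_n_times_num`, and it only covers first-order gradients of a scalar-valued function.

```python
import numpy as np


def numeric_grad(f, x, eps=1e-6):
    """Central-difference estimate of the gradient of scalar-valued f at x.

    Hypothetical helper for illustration; not part of the autodiff project.
    """
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    # Perturb one element at a time:
    # df/dx_i ~ (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
    for idx in np.ndindex(x.shape):
        x_plus = x.copy()
        x_minus = x.copy()
        x_plus[idx] += eps
        x_minus[idx] -= eps
        grad[idx] = (f(x_plus) - f(x_minus)) / (2 * eps)
    return grad


def f(x):
    # f(x) = sum(x**2), whose analytic gradient is 2*x
    return np.sum(x ** 2)


def g(x):
    # g(x) = sum(x): every partial derivative is 1
    return np.sum(x)


x = np.array([[1.0, -2.0], [0.5, 3.0]])

# Same shapes: compare the analytic and numerical gradients directly
print(np.allclose(2 * x, numeric_grad(f, x)))  # True

# A backprop engine might return the gradient of g as the scalar 1.0;
# broadcast it to x's shape before comparing, as numerical_check does
scalar_grad = 1.0
print(np.allclose(np.broadcast_to(scalar_grad, x.shape), numeric_grad(g, x)))  # True
```

Central differences are used because their truncation error is O(eps²) rather than the O(eps) of a one-sided difference, which keeps the comparison tolerance tight.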