def numerical_check(test, graph, wrt_vars, order=1):
    """Compare backprop gradients of ``graph`` against numerical estimates.

    For each variable in ``wrt_vars``, the backprop gradient graph and the
    numerical gradient returned by ``differentiate_n_times_num`` are printed
    and compared with ``arrays_allclose`` inside ``test.subTest`` so that one
    failing variable does not hide results for the others.

    Args:
        test: Object providing ``subTest`` (presumably a
            ``unittest.TestCase`` -- confirm against callers).
        graph: Graph node to differentiate. Its ``name`` attribute is read to
            add an extra label when it equals ``"extra_exp_op"``.
        wrt_vars: Variables to differentiate with respect to. Each is read for
            a ``name`` attribute and is called to obtain a value with a
            ``shape`` attribute.
        order: Differentiation order; forwarded to
            ``differentiate_n_times_num`` and embedded in the sub-test label.
    """
    backprop_graphs, numeric_grads = differentiate_n_times_num(graph, wrt_vars, order=order)
    for wrt_var, graph_grad, num_grad in zip(wrt_vars, backprop_graphs, numeric_grads):
        name = f"num{order}df_wrt_{wrt_var.name}"
        if graph.name == "extra_exp_op":
            name += " as input to another op!!!"
        with test.subTest(name):
            print("---------- " + name + " ----------")
            # Evaluate the backprop graph exactly once so the printed value is
            # the same value that gets compared (and the graph is not re-run).
            backprop_grad = graph_grad()
            print("Backprop grad:", backprop_grad)
            print("Numeric grad:", num_grad)
            # The backprop gradient is not necessarily the same shape as the
            # variable, so broadcast it up to the variable's shape before an
            # element-wise comparison.
            broadcasted_grad = np.broadcast_to(backprop_grad, wrt_var().shape)
            # NOTE(review): the return value is discarded; if arrays_allclose
            # returns a bool rather than asserting, mismatches go unreported
            # here -- confirm its contract.
            arrays_allclose(broadcasted_grad, num_grad)
评论列表
文章目录