def _testSurrogateLoss(self, session, losses, expected_addl_terms, xs):
    """Check `sg.surrogate_loss` against a hand-built expected surrogate loss.

    The expected surrogate loss is the sum (via `math_ops.add_n`) of the
    given `losses` plus the caller-supplied `expected_addl_terms`. This
    helper asserts two things:

    * the value of `sg.surrogate_loss(losses)` matches the expected
      surrogate loss, and
    * the gradients of both losses with respect to `xs` agree.

    Args:
      session: Session used to evaluate the loss and gradient tensors.
      losses: Sequence of loss tensors passed to `sg.surrogate_loss`.
      expected_addl_terms: Sequence of tensors that the surrogate loss is
        expected to add on top of `losses`.
      xs: Tensor(s) to differentiate both losses with respect to.
    """
    surrogate_loss = sg.surrogate_loss(losses)
    # Normalize both arguments to lists so any sequence type (list or tuple)
    # can be passed; `+` on mixed sequence types would raise TypeError.
    expected_surrogate_loss = math_ops.add_n(
        list(losses) + list(expected_addl_terms))
    self.assertAllClose(
        *session.run([surrogate_loss, expected_surrogate_loss]))

    # Test backprop: gradients of both losses w.r.t. xs must agree.
    expected_grads = gradients_impl.gradients(
        ys=expected_surrogate_loss, xs=xs)
    surrogate_grads = gradients_impl.gradients(ys=surrogate_loss, xs=xs)
    self.assertEqual(len(expected_grads), len(surrogate_grads))

    # Evaluate every gradient in a single run, then split the results back
    # into the expected half and the surrogate half.
    grad_values = session.run(expected_grads + surrogate_grads)
    n_grad = len(expected_grads)
    # Compare pair by pair so a failure pinpoints the mismatching gradient
    # and gradients of different shapes are never combined into one array.
    for expected_value, surrogate_value in zip(grad_values[:n_grad],
                                                grad_values[n_grad:]):
        self.assertAllClose(expected_value, surrogate_value)
# Source: stochastic_graph_test.py (file source code)
# Language: python
# Views: 25 · Favorites: 0 · Likes: 0 · Comments: 0
# Comment list
# Table of contents