nn_test.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:imperative 作者: yaroslavvb 项目源码 文件源码
def testL2Normalize(self):
    """Compare tf.nn.l2_normalize with the NumPy reference along each axis.

    A fixed seed makes the random float32 input reproducible. For every axis
    of the input shape, the expected output comes from the reference helper
    ``self._l2Normalize`` and is checked against the TensorFlow op with
    ``assertAllClose``.
    """
    shape = [20]
    np.random.seed(1)  # deterministic input across runs
    inputs = np.random.random_sample(shape).astype(np.float32)
    for axis in range(len(shape)):
      expected = self._l2Normalize(inputs, axis)
      with self.test_session():
        # Build the constant inside the session scope, as the op is evaluated there.
        actual = tf.nn.l2_normalize(tf.constant(inputs, name="x"), axis)
        self.assertAllClose(expected, actual.eval())

  # def testL2NormalizeGradient(self):
  #   x_shape = [20, 7, 3]
  #   np.random.seed(1)
  #   x_np = np.random.random_sample(x_shape).astype(np.float64)
  #   for dim in range(len(x_shape)):
  #     with self.test_session():
  #       x_tf = tf.constant(x_np, name="x")
  #       y_tf = tf.nn.l2_normalize(x_tf, dim)
  #       err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
  #     print("L2Normalize gradient err = %g " % err)
  #     self.assertLess(err, 1e-4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号