basic_test.py 文件源码

python
阅读 52 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testVariableInitialization(self):
    """Compares a matmul through snt.TrainableVariable against numpy.

    For each of float16/float32/float64, a TrainableVariable of shape [4, 6]
    is built with `_test_initializer()` and right-multiplies a [3, 4] input
    fed via a placeholder. The TF result is compared with `np.dot` on the
    same input and the fetched variable value, both cast to the tested
    dtype, using a per-dtype tolerance for both `atol` and `rtol`.
    """
    np.random.seed(100)
    # Ordered (dtype, tolerance) pairs; lower-precision dtypes get looser
    # tolerances. A tuple (not a dict) keeps iteration order explicit.
    dtype_tolerances = (
        (tf.float16, 1e-2),
        (tf.float32, 1e-6),
        (tf.float64, 1e-9),
    )
    input_shape = [3, 4]
    weight_shape = [4, 6]
    for dtype, tolerance in dtype_tolerances:
      inputs = tf.placeholder(dtype, shape=input_shape)
      module = snt.TrainableVariable(
          shape=weight_shape,
          dtype=dtype,
          initializers={"w": _test_initializer()})
      output = tf.matmul(inputs, module())
      with self.test_session() as session:
        # The random input is drawn after graph construction, matching the
        # original call order in case anything above also consumes the
        # numpy RNG (not verifiable from here).
        input_values = np.random.randn(*input_shape)
        session.run(tf.global_variables_initializer())
        actual, weights = session.run([output, module.w],
                                      {inputs: input_values})
      np_dtype = dtype.as_numpy_dtype
      expected = np.dot(input_values.astype(np_dtype),
                        weights.astype(np_dtype))
      self.assertAllClose(actual, expected, atol=tolerance, rtol=tolerance)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号