def testVariableInitialization(self):
    """Verify a TrainableVariable behaves like a plain matrix inside a matmul.

    For each floating-point dtype (float16, float32, float64, in that order),
    a random left-hand matrix is fed through `tf.matmul` against the
    variable's value. The result is compared with `np.dot` on the same
    operands, cast to that dtype's numpy type, using a tolerance that
    loosens as precision drops.
    """
    np.random.seed(100)
    lhs_shape = [3, 4]
    rhs_shape = [4, 6]
    # Ordered (dtype, tolerance) pairs. A tuple rather than a dict keeps the
    # iteration order fixed, which pins the sequence of np.random draws.
    dtype_tolerances = (
        (tf.float16, 1e-2),
        (tf.float32, 1e-6),
        (tf.float64, 1e-9),
    )
    for dtype, tolerance in dtype_tolerances:
        np_dtype = dtype.as_numpy_dtype
        lhs = tf.placeholder(dtype, shape=lhs_shape)
        trainable = snt.TrainableVariable(
            shape=rhs_shape,
            dtype=dtype,
            initializers={"w": _test_initializer()})
        output = tf.matmul(lhs, trainable())
        with self.test_session() as sess:
            lhs_value = np.random.randn(*lhs_shape)
            sess.run(tf.global_variables_initializer())
            actual, weights = sess.run([output, trainable.w], {lhs: lhs_value})
            # Cast both operands so numpy computes in the dtype under test.
            expected = np.dot(lhs_value.astype(np_dtype),
                              weights.astype(np_dtype))
            self.assertAllClose(actual, expected,
                                atol=tolerance, rtol=tolerance)
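# (Stray scraped web-page label, not code) Comment list
# (Stray scraped web-page label, not code) Table of contents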