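```python
import numpy as np
import tensorflow as tf


class ShuffleOpTest(tf.test.TestCase):
  def testShuffle(self):
    # Load the custom op from its compiled shared library.
    shuffle_module = tf.load_op_library('shuffle_op.so')
    shuffle = shuffle_module.shuffle

    # 12 elements into [6, -1]: the -1 dimension resolves to 2.
    input_tensor = np.arange(12).reshape((3, 4))
    desired_shape = np.array([6, -1])
    output_tensor = input_tensor.reshape((6, 2))
    with self.test_session():
      result = shuffle(input_tensor, desired_shape)
      self.assertAllEqual(result.eval(), output_tensor)

    # 12 elements into [5, -1]: the expected result is the first 10 elements as 5x2.
    input_tensor = np.arange(12).reshape((3, 4))
    desired_shape = np.array([5, -1])
    output_tensor = input_tensor.reshape((6, 2))[:-1]
    with self.test_session():
      result = shuffle(input_tensor, desired_shape)
      self.assertAllEqual(result.eval(), output_tensor)


if __name__ == '__main__':
  tf.test.main()
```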
Source file: shuffle_op_test.py
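The test implies, without stating it, how the op treats a `-1` that does not divide the element count evenly: for 12 elements and `[5, -1]` it expects the first 10 elements laid out as a 5x2 array. The NumPy sketch below models that inferred behaviour so the expected outputs can be checked without building `shuffle_op.so`. `shuffle_reference` is a hypothetical helper, not part of the op library; the floor-division rule for `-1`, the truncation of leftover elements, and the error raised when the shape needs more elements than are available are all assumptions read off the two test cases.

```python
import numpy as np


def shuffle_reference(x, shape):
    """Hypothetical NumPy model of the shape rule testShuffle checks.

    Assumption: flattens ``x`` in row-major order, resolves a single -1 in
    ``shape`` with floor division, and drops trailing elements that do not fit.
    """
    flat = np.asarray(x).ravel()
    shape = [int(d) for d in shape]
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if -1 in shape:
        known = int(np.prod([d for d in shape if d != -1]))
        # Floor division: 12 elements with [5, -1] gives 12 // 5 == 2 columns.
        shape[shape.index(-1)] = flat.size // known
    total = int(np.prod(shape))
    if total > flat.size:
        raise ValueError("requested shape needs more elements than x has")
    # Keep the first `total` elements; any remainder is discarded.
    return flat[:total].reshape(shape)


x = np.arange(12).reshape((3, 4))
print(shuffle_reference(x, [6, -1]).shape)  # (6, 2)
print(shuffle_reference(x, [5, -1]).shape)  # (5, 2)
```

Because the `[6, -1]` case divides evenly, it reduces to an ordinary `reshape`; only the `[5, -1]` case exercises the truncation that distinguishes this op from `np.reshape`, which would reject that shape outright.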