def test_function_otype_inference_tensor_to_tensor(self):
    """Output TensorType of a tensor->tensor function is inferred from input.

    Exercises `tdb._infer_tf_output_type_from_input_type` with two functions:

    * `tf.negative`, where the expected output type is the input type
      unchanged (same shape, same dtype), for scalar and [2, 3] shapes with
      the default dtype and with 'int32'.
    * a cast to 'int32', where the expected output keeps the input shape
      but has dtype 'int32', starting from 'float32' and 'float64' inputs.
    """
    infer_output_type = tdb._infer_tf_output_type_from_input_type

    def cast_to_int32(x):
        return tf.cast(x, 'int32')

    # Each entry is (function, input type, expected inferred output type);
    # entries are checked in the listed order.
    cases = [
        (tf.negative, tdt.TensorType([]), tdt.TensorType([])),
        (tf.negative, tdt.TensorType([2, 3]), tdt.TensorType([2, 3])),
        (tf.negative,
         tdt.TensorType([], 'int32'),
         tdt.TensorType([], 'int32')),
        (tf.negative,
         tdt.TensorType([2, 3], 'int32'),
         tdt.TensorType([2, 3], 'int32')),
        (cast_to_int32,
         tdt.TensorType([], 'float32'),
         tdt.TensorType([], 'int32')),
        (cast_to_int32,
         tdt.TensorType([2, 3], 'float64'),
         tdt.TensorType([2, 3], 'int32')),
    ]
    for fn, input_type, expected_type in cases:
        self.assertEqual(expected_type, infer_output_type(fn, input_type))
评论列表
文章目录