blocks_test.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def test_function_otype_inference_tensor_to_tensor(self):
    infer = tdb._infer_tf_output_type_from_input_type

    self.assertEqual(tdt.TensorType([]),
                     infer(tf.negative, tdt.TensorType([])))
    self.assertEqual(tdt.TensorType([2, 3]),
                     infer(tf.negative, tdt.TensorType([2, 3])))

    self.assertEqual(tdt.TensorType([], 'int32'),
                     infer(tf.negative, tdt.TensorType([], 'int32')))
    self.assertEqual(tdt.TensorType([2, 3], 'int32'),
                     infer(tf.negative, tdt.TensorType([2, 3], 'int32')))

    f = lambda x: tf.cast(x, 'int32')
    self.assertEqual(tdt.TensorType([], 'int32'),
                     infer(f, tdt.TensorType([], 'float32')))
    self.assertEqual(tdt.TensorType([2, 3], 'int32'),
                     infer(f, tdt.TensorType([2, 3], 'float64')))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号