def test_function_otype_inference_raises(self):
    """Check that output-type inference rejects unusable function results.

    Each case feeds `tdb._infer_tf_output_type_from_input_type` a function
    that ignores its argument and returns a fixed value, and asserts that a
    `TypeError` with the expected message is raised.
    """

    def infer(result):
        # Run inference on a constant function that always yields `result`.
        input_type = tdt.TensorType([])
        return tdb._infer_tf_output_type_from_input_type(
            lambda _: result, input_type)

    # A plain Python int is reported with an exact (literal) message.
    self.assertRaisesWithLiteralMatch(
        TypeError, '42 is not a TF tensor', infer, 42)

    # (expected error regex, placeholder shape) pairs, checked in order.
    # A shape of None is the same as omitting the shape argument.
    rejected_shapes = (
        ('unspecified rank', None),
        ('expected a batch tensor, saw a scalar', []),
        (r'leading \(batch\) dimension should be None', [0, 2]),
        ('instance shape is not fully defined', [None, 42, None, 5]),
    )
    for pattern, shape in rejected_shapes:
        # Build the placeholder per iteration so graph ops are created in
        # the same order as one assertion after another would create them.
        six.assertRaisesRegex(
            self, TypeError, pattern, infer,
            tf.placeholder('float32', shape))
评论列表
文章目录