def test_fc_raises(self):
    """Check that tdl.FC raises TypeError for inputs it should reject.

    Three cases are exercised, each matched against an error-message regex:
    a non-float32 dtype, a rank-0 (scalar) input, and a second call on an
    already-used layer whose input width differs from the first call's.
    """
    # Case 1: an int64 input must be rejected on dtype grounds.
    int_layer = tdl.FC(1)
    int_input = tf.constant([0], dtype='int64')
    with six.assertRaisesRegex(
            self, TypeError, 'FC input dtype must be float32'):
        int_layer(int_input)

    # Case 2: a scalar float32 tensor must fail the 1D-shape check.
    scalar_layer = tdl.FC(1)
    scalar_input = tf.constant(0, dtype='float32')
    with six.assertRaisesRegex(
            self, TypeError, 'FC input shape must be 1D'):
        scalar_layer(scalar_input)

    # Case 3: the first call with a (1, 1) input succeeds; a later call
    # with a (1, 2) input must raise a type mismatch (presumably because
    # the first call fixed the layer's input type -- confirm in tdl.FC).
    fc = tdl.FC(1)
    fc(tf.constant([[0]], 'float32'))
    wider_input = tf.constant([[0, 0]], 'float32')
    with six.assertRaisesRegex(
            self, TypeError, 'Type mismatch between input type'):
        fc(wider_input)
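# (web-page residue from scraping, translated: "Comment list")
# (web-page residue from scraping, translated: "Table of contents")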