def test_reshaped():
    """Check the argument shape seen by a ``@reshaped((4, 3))``-decorated function.

    A zeros tensor of shape (5, 12) is passed as the second argument of a
    decorated function. Inside that function, a ``tf.assert_equal`` op
    checks that the argument it actually receives has shape (5, 4, 3). The
    function's result is made control-dependent on the assert, so running
    the result in a graph-mode ``tf.Session`` fails if the shape differs.

    NOTE(review): this presumably exercises ``reshaped`` splitting the
    trailing dimension (12) into (4, 3) before the call. Confirm against
    ``reshaped``'s definition, which is not visible here. Also, ``None`` is
    passed as the first argument, which looks like a stand-in for a
    ``self``-style slot that the decorated function ignores.
    """
    inputs = tf.zeros((5, 12))

    @reshaped((4, 3))
    def my_func(_, a):
        # Build the check first, then gate the returned tensor on it so that
        # evaluating the output forces the shape assertion to run.
        shape_check = tf.assert_equal(tf.shape(a), (5, 4, 3))
        with tf.control_dependencies([shape_check]):
            return tf.identity(a)

    outputs = my_func(None, inputs)

    # The value is discarded; only the assertion's success matters.
    with tf.Session() as session:
        session.run(outputs)
评论列表
文章目录