def test_relation_layer(self):
    """Exercise GANComponent's split / permute / relation_layer helpers.

    For each spatial size (2 and 4), a zero tensor of shape
    [1, size, size, 1] is checked as follows:
      * split_by_width_height must yield size * size pieces;
      * permute(pieces, 2) must yield n * (n - 1) results for n pieces
        (12 for 4 pieces, 240 for 16), a count that matches ordered
        pairs of distinct pieces;
      * relation_layer must return a tensor whose shape, as reported by
        gan.ops.shape, equals the input shape.

    NOTE(review): `gan` is not defined in this method; it presumably
    comes from module scope in the enclosing test file. Confirm.
    """
    component = GANComponent(gan=gan, config={'test': True})
    with self.test_session():
        for size in (2, 4):
            input_shape = [1, size, size, 1]
            zeros = tf.zeros(input_shape)

            pieces = component.split_by_width_height(zeros)
            piece_count = size * size
            self.assertEqual(len(pieces), piece_count)

            pairs = component.permute(pieces, 2)
            self.assertEqual(len(pairs), piece_count * (piece_count - 1))

            output = component.relation_layer(zeros)
            self.assertEqual(gan.ops.shape(output), input_shape)
评论列表
文章目录