def test_concat(make_data):
    """Test the Concat layer built from two InputLayers.

    Both input layers ('X' and 'Y', each with ``n_samples=3``) are fed
    the same array ``x``. The concatenated output is therefore expected to
    equal the fixture tensor ``X`` stacked with itself along axis 2
    (``np.dstack``). The returned KL term is expected to be exactly zero.
    """
    x, _, X = make_data

    # NOTE(review): the comparison below assumes the fixture's ``X`` is what
    # a single InputLayer with n_samples=3 produces from ``x``; confirm
    # against the make_data fixture.
    layer_x = ab.InputLayer('X', n_samples=3)
    layer_y = ab.InputLayer('Y', n_samples=3)
    concat = ab.Concat(layer_x, layer_y)
    net, kl = concat(X=x, Y=x)

    case = tf.test.TestCase()
    with case.test_session():
        stacked = net.eval()
        reference = X.eval()

        # Leading two dims are unchanged; the third doubles in size.
        lead_dims = reference.shape[0:2]
        expected_shape = lead_dims + (reference.shape[2] * 2,)
        assert stacked.shape == expected_shape

        assert np.all(stacked == np.dstack((reference, reference)))
        assert kl.eval() == 0.0
评论列表
文章目录