def test_shape():
    """Check the MANN posterior's output shape and its rejection of bad inputs.

    Builds the graph returned by ``memory_augmented_neural_network`` for a
    fixed configuration and compiles it, then asserts that:

    * valid inputs of shape ``(batch_size, seq_len, input_size)`` with
      integer targets of shape ``(batch_size, seq_len)`` produce an output of
      shape ``(batch_size, seq_len, nb_class)``;
    * an input whose batch dimension differs from ``batch_size`` raises
      ``ValueError``;
    * an input whose last dimension differs from ``input_size`` raises
      ``ValueError``.
    """
    batch_size = 16
    nb_class = 5
    input_size = 20 * 20
    seq_len = 50
    float_x = theano.config.floatX
    # Local, seeded RNG: reproducible and leaves global NumPy state untouched.
    rng = np.random.RandomState(0)

    input_var = T.tensor3('input')
    target_var = T.imatrix('target')
    output_var, _, _ = memory_augmented_neural_network(
        input_var, target_var,
        batch_size=batch_size,
        nb_class=nb_class,
        memory_shape=(128, 40),
        controller_size=200,
        input_size=input_size,
        nb_reads=4)
    posterior_fn = theano.function([input_var, target_var], output_var)

    # Cast to floatX: T.tensor3 is typed with theano.config.floatX, and Theano
    # raises TypeError rather than silently downcasting float64 inputs when
    # floatX is float32. Without the cast, the ValueError checks below would
    # never be reached under that configuration.
    test_input = rng.rand(batch_size, seq_len, input_size).astype(float_x)
    test_target = rng.randint(nb_class, size=(batch_size, seq_len)).astype('int32')
    test_input_invalid_batch_size = rng.rand(
        batch_size + 1, seq_len, input_size).astype(float_x)
    test_input_invalid_depth = rng.rand(
        batch_size, seq_len, input_size - 1).astype(float_x)

    test_output = posterior_fn(test_input, test_target)
    assert test_output.shape == (batch_size, seq_len, nb_class)

    with pytest.raises(ValueError):
        posterior_fn(test_input_invalid_batch_size, test_target)
    with pytest.raises(ValueError):
        posterior_fn(test_input_invalid_depth, test_target)
# Page-scrape residue, not code: "评论列表" ("comment list"), "文章目录" ("table of contents").