def test_keras_model_probs(num_classes):
    """Check KerasModel predictions for logit and probability outputs.

    A small Keras graph (5x5 input, global average pooling, then a
    softmax activation) is wrapped three times in ``KerasModel``: once
    on the pooled output with ``predicts='logits'`` and twice on the
    softmax output with ``predicts='probabilities'`` and
    ``predicts='probs'``. All three wrappers must return predictions of
    shape ``(2, num_classes)`` for a batch of two images, and the
    predictions must agree after subtracting each array's maximum.

    Args:
        num_classes: Number of classes. It is also used as the channel
            count of the input, so pooling yields one value per class.
    """
    batch_size = 2
    value_range = (0, 255)
    n_channels = num_classes

    # Only DeprecationWarnings raised while building the graph and the
    # wrappers are silenced.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=DeprecationWarning)

        net_input = Input(shape=(5, 5, n_channels))
        pooling = GlobalAveragePooling2D(data_format='channels_last')
        pooled = pooling(net_input)
        normalized = Activation(softmax)(pooled)

        # (output tensor, `predicts` mode) for each wrapper, in the
        # order they are compared below.
        # NOTE(review): 'probs' is presumably an alias of
        # 'probabilities' -- confirm against KerasModel.
        wrapper_specs = (
            (pooled, 'logits'),
            (normalized, 'probabilities'),
            (normalized, 'probs'),
        )
        wrapped_models = [
            KerasModel(
                Model(inputs=net_input, outputs=output),
                bounds=value_range,
                predicts=mode)
            for output, mode in wrapper_specs
        ]

    # Fixed seed so the random batch is reproducible.
    np.random.seed(22)
    images = np.random.rand(batch_size, 5, 5, n_channels)
    images = images.astype(np.float32)

    from_logits, from_probabilities, from_probs = (
        wrapped.batch_predictions(images) for wrapped in wrapped_models
    )

    expected_shape = (batch_size, num_classes)
    assert (from_logits.shape == from_probabilities.shape
            == from_probs.shape == expected_shape)

    # Logits vs. probabilities: compared with a loose tolerance.
    np.testing.assert_array_almost_equal(
        from_logits - from_logits.max(),
        from_probabilities - from_probabilities.max(),
        decimal=1)

    # 'probabilities' vs. 'probs': compared with a tight tolerance.
    np.testing.assert_array_almost_equal(
        from_probabilities - from_probabilities.max(),
        from_probs - from_probs.max(),
        decimal=5)
# (Web page residue, not part of the test: "Comment list" / "Table of contents".)