def test_model_pipe_mnist_urls(self):
    """End-to-end check of ``model_util.ModelPipe`` on MNIST digits fetched by URL.

    Trains a small dense classifier on MNIST, then builds a pipe that
    downloads each URL, decodes the bytes with PIL, resizes the image to the
    model's input, inverts intensities (``1 - x``), runs the model in batches
    stacked with ``np.vstack`` and takes the arg-max class. The two known
    digit images must be classified as 5 and 2.

    NOTE(review): this test needs network access to the URLs below and
    trains the model for several epochs, so it is slow and can be flaky
    when offline.
    """
    from contextlib import closing

    num_classes = 10
    no_epochs = 3
    batch_size = 32

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Add the trailing channel axis required by the model's input_shape and
    # scale pixels to [0, 1]; -1 lets numpy infer the sample count.
    x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255
    x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    model.fit(
        x_train, y_train,
        validation_data=(x_test, y_test),
        epochs=no_epochs,
        batch_size=batch_size)

    def fetch(url):
        """Download *url* and return the raw response body as bytes."""
        # Close the HTTP response deterministically instead of leaking it
        # until garbage collection; ``closing`` works for both the Python 2
        # and Python 3 response objects returned via six.
        with closing(six.moves.urllib.request.urlopen(url)) as response:
            return response.read()

    pipe = model_util.ModelPipe()
    pipe.add(fetch, num_workers=2, timeout=10)
    pipe.add(lambda img: Image.open(BytesIO(img)))
    pipe.add(model_util.resize_to_model_input(model))
    # Invert intensities. Presumably the hosted images are dark digits on a
    # light background while MNIST is light-on-dark — TODO confirm.
    pipe.add(lambda x: 1 - x)
    pipe.add(model, num_workers=1, batch_size=32, batcher=np.vstack)
    pipe.add(lambda x: np.argmax(x, axis=1))

    url5 = 'http://blog.otoro.net/assets/20160401/png/mnist_output_10.png'
    url2 = 'http://joshmontague.com/images/mnist-2.png'
    urlb = 'http://joshmontague.com/images/mnist-3.png'
    expected_output = {url5: 5, url2: 2}
    output = pipe({url5: url5, url2: url2, urlb: urlb})
    # Keep only truthy results. urlb is absent from expected_output, so the
    # test presumably relies on its result being falsy (e.g. a failed fetch)
    # — TODO confirm.
    # NOTE(review): a genuine prediction of class 0 is also falsy and would
    # be dropped by this filter.
    output = {k: v for k, v in six.iteritems(output) if v}
    # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
    self.assertEqual(output, expected_output)
评论列表
文章目录