model_util_test.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:studio 作者: studioml 项目源码 文件源码
def test_model_pipe_mnist_urls(self):
        """Train a small MNIST classifier and run it through a URL-fed ModelPipe.

        The test trains a Flatten -> Dense(128) -> Dense(128) -> Dense(10)
        softmax network on MNIST for 3 epochs. It then builds a
        ``model_util.ModelPipe`` that:

        1. downloads image bytes from a URL (2 workers, ``timeout=10``),
        2. decodes them with PIL,
        3. resizes them to the model's input shape,
        4. inverts the pixel values (``1 - x``),
        5. runs the model in batches of 32 (stacked with ``np.vstack``),
        6. takes the argmax as the predicted digit.

        Two URLs are expected to be classified as 5 and 2. The third URL
        (``urlb``) must produce a falsy pipe output so that it is filtered
        out. Presumably the pipe yields None for it; confirm against
        ModelPipe.
        """
        import contextlib

        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        # Add a trailing channel axis and scale pixel values into [0, 1].
        # -1 lets numpy infer the sample count instead of hard-coding it.
        x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255
        x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255

        y_train = to_categorical(y_train, 10)
        y_test = to_categorical(y_test, 10)

        model = Sequential()
        model.add(Flatten(input_shape=(28, 28, 1)))
        model.add(Dense(128, activation='relu'))
        model.add(Dense(128, activation='relu'))
        model.add(Dense(10, activation='softmax'))

        no_epochs = 3
        batch_size = 32

        model.compile(loss='categorical_crossentropy', optimizer='adam')
        model.fit(
            x_train, y_train,
            validation_data=(x_test, y_test),
            epochs=no_epochs,
            batch_size=batch_size)

        def _fetch(url):
            # Download the raw bytes at ``url`` and always close the HTTP
            # response. contextlib.closing is used because Python 2
            # responses are not context managers.
            with contextlib.closing(
                    six.moves.urllib.request.urlopen(url)) as response:
                return response.read()

        pipe = model_util.ModelPipe()
        pipe.add(_fetch, num_workers=2, timeout=10)
        pipe.add(lambda img: Image.open(BytesIO(img)))
        pipe.add(model_util.resize_to_model_input(model))
        # NOTE(review): presumably inverts web images to match MNIST's
        # light-on-dark polarity; assumes values are already in [0, 1].
        pipe.add(lambda x: 1 - x)
        pipe.add(model, num_workers=1, batch_size=32, batcher=np.vstack)
        pipe.add(lambda x: np.argmax(x, axis=1))

        url5 = 'http://blog.otoro.net/assets/20160401/png/mnist_output_10.png'
        url2 = 'http://joshmontague.com/images/mnist-2.png'
        urlb = 'http://joshmontague.com/images/mnist-3.png'

        expected_output = {url5: 5, url2: 2}
        output = pipe({url5: url5, url2: url2, urlb: urlb})
        # Drop entries with falsy results (expected for urlb).
        # NOTE(review): this would also drop a genuine prediction of digit 0.
        output = {k: v for k, v in six.iteritems(output) if v}

        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(output, expected_output)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号