keras_example.py source code

Project: dsde-deep-learning    Author: broadinstitute
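```python
import random

from keras.models import Sequential
from keras.layers import Dense

# load_data, make_one_hot and plot_templates are helpers assumed to be defined
# elsewhere in keras_example.py; they are not shown in this excerpt.


def logistic_regression():
    # load_data is assumed to return (images, labels) pairs, with each image
    # flattened to 28 x 28 = 784 pixels.
    train, test, valid = load_data('mnist.pkl.gz')

    epochs = 3200  # number of mini-batch updates, not full passes over the data
    num_labels = 10
    train_y = make_one_hot(train[1], num_labels)
    valid_y = make_one_hot(valid[1], num_labels)
    test_y = make_one_hot(test[1], num_labels)

    # Multinomial logistic regression: one dense layer with a softmax output.
    logistic_model = Sequential()
    logistic_model.add(Dense(10, activation='softmax', input_dim=784, name='mnist_templates'))
    logistic_model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    logistic_model.summary()

    # The (784, 10) weight matrix holds one 784-pixel "template" per digit class.
    templates = logistic_model.layers[0].get_weights()[0]
    plot_templates(templates, 0)
    print('weights shape:', templates.shape)

    for e in range(epochs):
        # Draw a random mini-batch of 8192 training examples without replacement.
        trainidx = random.sample(range(train[0].shape[0]), 8192)
        x_batch = train[0][trainidx, :]
        y_batch = train_y[trainidx]
        logistic_model.train_on_batch(x_batch, y_batch)
        if e % 5 == 0:
            # Snapshot the learned templates every 5 updates.
            plot_templates(logistic_model.layers[0].get_weights()[0], e)

    # evaluate() returns [loss, accuracy] because the model was compiled with metrics=['accuracy'].
    print('Test set loss and accuracy:', logistic_model.evaluate(test[0], test_y))
```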
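
The helper `make_one_hot` is not part of this excerpt. As a hedged illustration only, a one-hot encoder with the same call shape, `make_one_hot(labels, num_labels)`, could look like the NumPy sketch below. This function is hypothetical; the repository's real helper may be implemented differently.

```python
import numpy as np


def make_one_hot(labels, num_labels):
    """Hypothetical stand-in for the repository's helper.

    Maps an integer label vector of shape (n,) to a float32 matrix of
    shape (n, num_labels) with a single 1.0 in each row.
    """
    labels = np.asarray(labels, dtype=np.int64)
    one_hot = np.zeros((labels.shape[0], num_labels), dtype=np.float32)
    # Row i gets a 1.0 in column labels[i].
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot


if __name__ == '__main__':
    y = make_one_hot([3, 0, 9], 10)
    print(y.shape)           # (3, 10)
    print(y.argmax(axis=1))  # [3 0 9]
```

Taking `argmax` along each row recovers the original integer labels, which is a quick sanity check for any one-hot encoder.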
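
A single `Dense(10, activation='softmax')` layer trained with categorical cross-entropy is multinomial logistic regression. The NumPy sketch below spells out, under stated assumptions, roughly what one `train_on_batch` call computes for such a model: a softmax over `x @ W + b`, the mean cross-entropy, and one plain gradient-descent update. It is an illustration, not Keras's implementation: it assumes a learning rate of 0.01 (the default of Keras's `SGD` optimizer), omits Keras internals such as probability clipping, and uses synthetic data in place of MNIST. The names `softmax` and `sgd_step` are hypothetical.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def sgd_step(W, b, x_batch, y_batch, lr=0.01):
    """One plain-SGD update of softmax regression on a mini-batch.

    W: (784, 10) weights ("templates"), b: (10,) bias,
    x_batch: (n, 784) inputs, y_batch: (n, 10) one-hot targets.
    Returns the updated (W, b) and the mean cross-entropy before the update.
    """
    n = x_batch.shape[0]
    probs = softmax(x_batch @ W + b)
    # Mean categorical cross-entropy; the small epsilon guards against log(0).
    loss = -np.mean(np.sum(y_batch * np.log(probs + 1e-12), axis=1))
    # Gradient of the mean loss with respect to the logits is (probs - targets) / n.
    grad_logits = (probs - y_batch) / n
    W = W - lr * (x_batch.T @ grad_logits)
    b = b - lr * grad_logits.sum(axis=0)
    return W, b, loss


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    # Synthetic stand-in for MNIST: 512 random "images" with random labels.
    x = rng.random((512, 784)).astype(np.float32)
    labels = rng.integers(0, 10, size=512)
    y = np.eye(10, dtype=np.float32)[labels]
    W = rng.normal(scale=0.01, size=(784, 10)).astype(np.float32)
    b = np.zeros(10, dtype=np.float32)
    for step in range(3):
        W, b, loss = sgd_step(W, b, x, y)
        print(step, round(float(loss), 4))
```

Each column of `W` has one weight per input pixel for a single digit class, so reshaping a column to 28 x 28 gives an image-like template; this is presumably what `plot_templates` renders from the layer's weights.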