def logistic_regression(data_path='mnist.pkl.gz', epochs=3200, batch_size=8192,
                        plot_every=5, num_labels=10):
    """Train a softmax (multinomial logistic regression) classifier with Keras.

    Builds a single Dense softmax layer and trains it with plain SGD on
    randomly sampled mini-batches. Along the way it plots the learned weight
    "templates": once before training, then every ``plot_every`` steps. At the
    end it prints the loss and accuracy on the test split.

    All parameters default to the values previously hard-coded, so calling
    ``logistic_regression()`` behaves as before.

    Args:
        data_path: Path passed to ``load_data``.
        epochs: Number of training steps. Each step is one ``train_on_batch``
            call on a freshly sampled batch, not a full pass over the data.
        batch_size: Number of training rows sampled (without replacement) per
            step. Capped at the size of the training set.
        plot_every: Plot the weight templates every ``plot_every`` steps.
            A falsy value (0 / None) disables the periodic plots.
        num_labels: Number of output classes, i.e. the width of the softmax
            layer and of the one-hot targets.
    """
    # NOTE(review): the unpacking order (train, test, valid) is kept from the
    # original code; confirm it matches what load_data actually returns.
    # The validation split is not used by this function.
    train, test, _valid = load_data(data_path)
    train_x = train[0]
    train_y = make_one_hot(train[1], num_labels)
    test_y = make_one_hot(test[1], num_labels)

    logistic_model = Sequential()
    # A single Dense layer with softmax activation is multinomial logistic
    # regression. The output width matches the one-hot targets, and the input
    # width is read from the data instead of being hard-coded.
    logistic_model.add(Dense(num_labels, activation='softmax',
                             input_dim=train_x.shape[1], name='mnist_templates'))
    logistic_model.compile(loss='categorical_crossentropy', optimizer='sgd',
                           metrics=['accuracy'])
    logistic_model.summary()

    # get_weights()[0] is the Dense layer's kernel (the weight matrix without
    # the bias); these are the "templates" that get plotted.
    templates = logistic_model.layers[0].get_weights()[0]
    plot_templates(templates, 0)
    print('weights shape:', templates.shape)

    n_train = train_x.shape[0]
    # random.sample raises ValueError when asked for more items than exist,
    # so never request more rows than the training set holds.
    sample_size = min(batch_size, n_train)
    for e in range(epochs):
        trainidx = random.sample(range(n_train), sample_size)
        x_batch = train_x[trainidx, :]
        y_batch = train_y[trainidx]
        logistic_model.train_on_batch(x_batch, y_batch)
        if plot_every and e % plot_every == 0:
            # NOTE(review): at e == 0 this reuses index 0, which the
            # pre-training plot above also used. Confirm this is intended.
            plot_templates(logistic_model.layers[0].get_weights()[0], e)

    print('Test set loss and accuracy:', logistic_model.evaluate(test[0], test_y))
keras_example.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录