def main(unused_argv):
    """Train the ``char_rnn_model`` estimator on DBpedia and print test accuracy.

    The function loads the ``'dbpedia'`` dataset, encodes the inputs with a
    ``ByteProcessor`` capped at ``MAX_DOCUMENT_LENGTH``, fits an ``Estimator``
    for 100 steps, predicts on the test split and prints the accuracy score.

    Args:
        unused_argv: Ignored. Presumably present because the program's entry
            point runner passes leftover command-line arguments -- confirm
            against the caller.
    """
    # Load the dataset; the fake-data switch is forwarded from the flags.
    dataset = learn.datasets.load_dataset(
        'dbpedia', test_with_fake_data=FLAGS.test_with_fake_data)

    # Column 1 of each split's data is the model input (presumably the
    # document body text -- verify against the dataset loader).
    train_docs = pandas.DataFrame(dataset.train.data)[1]
    train_labels = pandas.Series(dataset.train.target)
    test_docs = pandas.DataFrame(dataset.test.data)[1]
    test_labels = pandas.Series(dataset.test.target)

    # The processor is fitted on the training inputs and then reused as-is
    # for the test inputs, so both splits share the same encoding.
    byte_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
    train_features = np.array(list(byte_processor.fit_transform(train_docs)))
    test_features = np.array(list(byte_processor.transform(test_docs)))

    # Build the estimator around the model function and train it.
    estimator = learn.Estimator(model_fn=char_rnn_model)
    estimator.fit(train_features, train_labels, steps=100)

    # Each prediction is indexed by 'class' to obtain the predicted label.
    predictions = estimator.predict(test_features, as_iterable=True)
    predicted_labels = [prediction['class'] for prediction in predictions]

    accuracy = metrics.accuracy_score(test_labels, predicted_labels)
    print('Accuracy: {0:f}'.format(accuracy))
评论列表
文章目录