def check_model(path=MODEL_PATH, file=SAMPLE_CSV_FILE, nsamples=2):
    """Print model predictions next to the expected answers for a few samples.

    Loads a model from ``path`` and data from ``file``, predicts on the first
    ``nsamples`` encoded (row, question) pairs, then prints the raw prediction,
    the arg-max ids, the true ids and a decoded view of every displayed
    sample.

    Args:
        path: Location handed to ``load_model``.
        file: Data file handed to ``get_data``.
        nsamples: Maximum number of leading samples to predict and display.
            Fewer are shown when the data holds fewer samples.
    """
    model = load_model(path)
    data, dic = get_data(file)
    rows, questions, true_answers = encode_data(data, dic)

    # Slice once so prediction and display work on exactly the same samples.
    sample_rows = rows[:nsamples]
    sample_questions = questions[:nsamples]
    sample_answers = true_answers[:nsamples]

    prediction = model.predict([sample_rows, sample_questions])
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(prediction)
    # For every position of every sample keep the index of the largest score.
    predicted_answers = [[np.argmax(character) for character in sample]
                         for sample in prediction]
    print(predicted_answers)
    print(sample_answers)

    # Invert the dictionary so ids can be mapped back for display.
    # .items() replaces the Python 2-only .iteritems().
    inv_dic = {char_id: char for char, char_id in dic.items()}

    def decode(char_ids):
        # Id 0 is skipped -- presumably the padding id; confirm against
        # encode_data.
        return ''.join(inv_dic[char_id] for char_id in char_ids if char_id != 0)

    # zip stops at the shortest sequence, so asking for more samples than the
    # data contains no longer raises IndexError.
    samples = zip(sample_rows, sample_questions, sample_answers, predicted_answers)
    for row, question, answer, predicted in samples:
        print('\n')
        print('Table: ' + decode(row))
        print('Question: ' + decode(question))
        print('Answer(correct): ' + decode(answer))
        print('Answer(predicted): ' + decode(predicted))
# Example source code using Python's plot_model()
def rcnn_mtl(processed_datasets, index_embedding, params):
    """Train the multi-task model batch by batch with periodic test evaluation.

    Args:
        processed_datasets: 4-tuple ``(x_trains, y_trains, x_tests, y_tests)``.
        index_embedding: Passed through to ``build_models``.
        params: Mapping read for ``'batch_size'``, ``'iterations'`` and
            ``'prediction_path'``; also passed to ``build_models``.

    NOTE(review): the indentation of this function was lost in the copy this
    was restored from. Saving predictions is assumed to happen inside the
    periodic evaluation branch -- confirm against the upstream source.
    """
    start = datetime.datetime.now()
    x_trains, y_trains, x_tests, y_tests = processed_datasets
    mtl_model, single_models = build_models(params, index_embedding)
    # Keras-style summary() prints the table itself and returns None, so
    # wrapping it in print() only added a stray "None" line.
    mtl_model.summary()

    # These dicts are handed to generate_batch_data and then to
    # train_on_batch -- presumably filled in place on every call; verify.
    batch_input = {}
    batch_output = {}
    batch_size = params['batch_size']
    iterations = params['iterations']
    sys.stdout.write('\ntotal iterations: {}'.format(iterations))

    itera = 0
    while itera < iterations:
        generate_batch_data(batch_input, batch_output, batch_size, x_trains, y_trains)
        mtl_model.train_on_batch(batch_input, batch_output)
        itera += 1

        # Report every 100th iteration once past the first 200.
        if itera > 200 and itera % 100 == 0:
            sys.stdout.write('\n\ncurrent iteration: {}'.format(itera))
            evaluate(single_models, x_tests, y_tests, 'test')
            sys.stdout.flush()
            # From iteration 500 on, also write the test predictions.
            if itera >= 500:
                save_predictions(single_models, x_tests, params['prediction_path'])

    end = datetime.datetime.now()
    sys.stdout.write('\nused time: {}\n'.format(end - start))
def train_and_evaluate(train, test, intents_lookup, save=False):
    """Fit a model on ``train`` and report support-weighted F1 scores.

    Args:
        train: Training data handed to ``prepare_inputs_and_outputs``.
        test: Optional evaluation data; when falsy, no validation is run and
            the returned test score is ``None``.
        intents_lookup: Iterable of intents; its length sizes the model and it
            is passed to ``prepare_inputs_and_outputs``.
        save: When true, write the model and a ``stats.json`` file under
            ``MODEL_OUTPUT_FOLDER``.

    Returns:
        Tuple ``(history, f1_test, f1_train)``. ``history`` is a dict with
        ``'train'`` and ``'test'`` arrays built from the fit history's
        ``'f1_score'`` and ``'val_f1_score'`` entries (the latter empty when
        absent).
    """
    train_inputs, train_labels = prepare_inputs_and_outputs(train, intents_lookup)
    validation_data = None
    if test:
        test_inputs, test_labels = prepare_inputs_and_outputs(test, intents_lookup)
        validation_data = (test_inputs, test_labels)

    # Column sums of the label matrices, printed per split.
    print('Number of sentences for each intent, train and test')
    print(list(intents_lookup))
    print(train_labels.sum(axis=0))
    if test:
        print(test_labels.sum(axis=0))

    model = create_model(len(intents_lookup))
    fit_result = model.fit(
        train_inputs,
        train_labels,
        validation_data=validation_data,
        epochs=MAX_ITERATIONS,
        batch_size=50,
    )

    # Only the F1 curves are kept from the fit history.
    recorded = fit_result.history
    f1_curves = {
        'train': np.array(recorded['f1_score']),
        'test': np.array(recorded.get('val_f1_score', [])),
    }

    # F1 weighted by support, computed on arg-max class indices.
    train_predictions = model.predict(train_inputs)
    f1_train = f1_score(
        train_labels.argmax(axis=1),
        train_predictions.argmax(axis=1),
        average='weighted',
    )

    if not test:
        f1_test = None
    else:
        test_predictions = model.predict(test_inputs)
        f1_test = f1_score(
            test_labels.argmax(axis=1),
            test_predictions.argmax(axis=1),
            average='weighted',
        )

    print(f1_test, f1_train)

    # NOTE(review): all persistence steps are assumed to sit under ``save``;
    # the original indentation was lost -- confirm.
    if save:
        model.save(MODEL_OUTPUT_FOLDER + '/model.h5')
        stats = {'model_name': MODEL_NAME, 'model': model.get_config()}
        with open(MODEL_OUTPUT_FOLDER + '/stats.json', 'w+') as stats_file:
            json.dump(stats, stats_file)

    return f1_curves, f1_test, f1_train