def rcnn_mtl(processed_datasets, index_embedding, params):
    """Train the multi-task RCNN model and periodically evaluate it on the test data.

    Builds the multi-task model and its per-task single models with
    ``build_models``, then runs ``params['iterations']`` training steps. Each
    step draws a batch with ``generate_batch_data`` and calls
    ``mtl_model.train_on_batch``. After more than ``eval_start`` steps, every
    ``eval_every`` steps the single models are evaluated on the test data;
    from ``save_start`` steps onward, test predictions are also written with
    ``save_predictions``. Elapsed wall-clock time is printed at the end.

    Args:
        processed_datasets: 4-tuple ``(x_trains, y_trains, x_tests, y_tests)``.
        index_embedding: passed unchanged to ``build_models``.
        params: settings dict. Required keys: ``'batch_size'``,
            ``'iterations'``, and ``'prediction_path'`` (read only when
            predictions are saved). Optional keys, whose defaults reproduce
            the previously hard-coded schedule: ``'eval_start'`` (200),
            ``'eval_every'`` (100), ``'save_start'`` (500).
    """
    start = datetime.datetime.now()
    x_trains, y_trains, x_tests, y_tests = processed_datasets
    mtl_model, single_models = build_models(params, index_embedding)
    # Keras' Model.summary() prints the table itself and returns None;
    # wrapping it in print() only emitted an extra "None" line.
    mtl_model.summary()
    # plot_model(mtl_model, to_file='mtl_model.png', show_shapes=True)

    batch_size = params['batch_size']
    iterations = params['iterations']
    # Evaluation / saving schedule; defaults match the original constants.
    eval_start = params.get('eval_start', 200)
    eval_every = params.get('eval_every', 100)
    save_start = params.get('save_start', 500)

    # Reused across steps; generate_batch_data's return value is ignored, so it
    # presumably fills these dicts in place — confirm against its definition.
    batch_input = {}
    batch_output = {}

    sys.stdout.write('\ntotal iterations: {}'.format(iterations))
    itera = 0
    while itera < iterations:
        generate_batch_data(batch_input, batch_output, batch_size, x_trains, y_trains)
        mtl_model.train_on_batch(batch_input, batch_output)
        itera += 1  # counted before the checks: they see completed steps
        if itera > eval_start and itera % eval_every == 0:
            sys.stdout.write('\n\ncurrent iteration: {}'.format(itera))
            # evaluate(single_models, x_trains, y_trains, 'train')
            evaluate(single_models, x_tests, y_tests, 'test')
            sys.stdout.flush()
            # NOTE(review): indentation was lost in the source; saving is assumed
            # to happen at each evaluation checkpoint once past save_start — confirm.
            if itera >= save_start:
                save_predictions(single_models, x_tests, params['prediction_path'])
    # save_models(single_models, params['save_model_path'])
    end = datetime.datetime.now()
    sys.stdout.write('\nused time: {}\n'.format(end - start))
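# Web-page residue from the scraped article ("评论列表" = "Comment list",
# "文章目录" = "Table of contents"); not part of the program.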