def find_nearest_instances(training_data_instances, training_data_labels,
                           test_data_instances, test_data_labels,
                           number_of_processes=4, number_of_classes=10):
    """Classify every test instance in parallel and summarise the errors.

    The test set is split into ``number_of_processes`` contiguous index
    ranges. Each non-empty range ``[start, end)`` is handed to
    ``find_nearest_instances_subprocess`` (defined elsewhere in this module)
    in its own process, together with a shared integer array into which the
    worker is presumably expected to write one predicted label per test
    index -- TODO confirm against the worker's implementation.

    Args:
        training_data_instances: training instances, passed unchanged to
            every worker.
        training_data_labels: labels for ``training_data_instances``, passed
            unchanged to every worker.
        test_data_instances: sequence of instances to classify.
        test_data_labels: true labels of ``test_data_instances``; each is used
            as a row index into the confusion matrix, so it must be an integer
            in ``[0, number_of_classes)``.
        number_of_processes: number of worker processes to use (default 4).
        number_of_classes: side length of the square confusion matrix
            (default 10).

    Returns:
        A tuple ``(classified_results, error_rate, confusion_matrix)``:
        ``classified_results`` is the shared ``multiprocessing`` integer array
        holding the predicted labels; ``error_rate`` is the percentage
        (0-100) of test instances whose predicted label differs from the true
        one; ``confusion_matrix[true_label, predicted_label]`` counts outcomes.

    Raises:
        ValueError: if ``test_data_instances`` is empty or
            ``number_of_processes`` is less than 1.
        RuntimeError: if any worker process exits with a non-zero exit code,
            because its slice of ``classified_results`` would be unfilled.
    """
    number_of_tests = len(test_data_instances)
    if number_of_tests == 0:
        raise ValueError("test_data_instances must not be empty")
    if number_of_processes < 1:
        raise ValueError("number_of_processes must be at least 1")

    start_time = time.time()
    # The results must live in shared memory: arguments given to a child
    # process are copies, so changes the child makes to ordinary Python
    # objects are never seen by the parent. No lock is requested because
    # every worker is given a disjoint index range.
    classified_results = multiprocessing.Array('i', number_of_tests, lock=False)
    # Proportional integer boundaries: always exactly number_of_processes + 1
    # entries running from 0 to number_of_tests, so every index is assigned to
    # exactly one worker even when number_of_tests is not divisible by (or is
    # smaller than) number_of_processes.
    boundaries = [number_of_tests * index // number_of_processes
                  for index in range(number_of_processes + 1)]
    processes = []
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        if start == end:
            # Empty slice (only possible when there are fewer tests than
            # processes); no worker is needed for it.
            continue
        process = multiprocessing.Process(
            target=find_nearest_instances_subprocess,
            args=(training_data_instances,
                  training_data_labels,
                  test_data_instances,
                  start,
                  end,
                  classified_results))
        process.start()
        processes.append(process)
    print("Waiting...")
    # Wait until all processes are finished.
    for process in processes:
        process.join()
    failed_exit_codes = [process.exitcode for process in processes
                         if process.exitcode != 0]
    if failed_exit_codes:
        # A crashed worker leaves its slice at the array's initial zeros,
        # which would otherwise be scored as genuine predictions.
        raise RuntimeError("%d worker process(es) failed with exit codes %r"
                           % (len(failed_exit_codes), failed_exit_codes))
    print("Complete.")
    print("--- %s seconds ---" % (time.time() - start_time))

    error_count = 0
    # Plain ``int`` dtype: the ``np.int`` alias was removed in NumPy 1.24.
    confusion_matrix = np.zeros((number_of_classes, number_of_classes),
                                dtype=int)
    for test_instance_index, classified_label in enumerate(classified_results):
        true_label = test_data_labels[test_instance_index]
        if true_label != classified_label:
            error_count += 1
        confusion_matrix[true_label, classified_label] += 1
    error_rate = 100.0 * error_count / number_of_tests
    return classified_results, error_rate, confusion_matrix
# 评论列表 ("comment list")
# 文章目录 ("table of contents")