def testing(config_info):
    """Evaluate every saved checkpoint in ``model_dir`` on the test split.

    For each checkpoint id returned by ``get_model_id_list(model_dir)``
    (processed in ascending order), restore ``<model_dir>/model-<id>``, run
    ``model.test()`` and pickle the returned object to
    ``<output_dir>/test_result_<id>.pkl``. Checkpoints whose result file
    already exists are skipped, so an interrupted run can simply be restarted.

    Args:
        config_info: mapping of configuration values. Keys read here:
            ``data_folder``, ``func_path``, ``embed_path``, ``tag``,
            ``data_tag``, ``process_num``, ``embed_dim``, ``max_length``,
            ``num_classes`` (these four are converted with ``int``),
            ``model_dir`` and ``output_dir``. The whole mapping is also
            passed to ``Model``.
    """
    data_folder = config_info['data_folder']
    func_path = config_info['func_path']
    embed_path = config_info['embed_path']
    tag = config_info['tag']
    data_tag = config_info['data_tag']
    process_num = int(config_info['process_num'])
    embed_dim = int(config_info['embed_dim'])
    max_length = int(config_info['max_length'])
    num_classes = int(config_info['num_classes'])
    model_dir = config_info['model_dir']
    output_dir = config_info['output_dir']

    # Create the output folder, including any missing parent directories
    # (os.mkdir would fail if a parent does not exist).
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    print('Created all folders!')

    # 'callee' selects dataset.Dataset; any other value falls back to
    # dataset_caller.Dataset.
    if data_tag == 'callee':
        my_data = dataset.Dataset(data_folder, func_path, embed_path, process_num,
                                  embed_dim, max_length, num_classes, tag)
    else:  # caller
        my_data = dataset_caller.Dataset(data_folder, func_path, embed_path, process_num,
                                         embed_dim, max_length, num_classes, tag)
    print('Created the dataset!')

    # Checkpoint ids, evaluated in ascending order.
    model_id_list = sorted(get_model_id_list(model_dir))

    with tf.Graph().as_default(), tf.Session() as session:
        # Build the input placeholders and the model graph once; each
        # checkpoint's weights are then restored into this same graph.
        data_pl, label_pl, length_pl, keep_prob_pl = placeholder_inputs(num_classes, max_length, embed_dim)
        model = Model(session, my_data, config_info, data_pl, label_pl, length_pl, keep_prob_pl)
        print('Created the model!')

        for model_id in model_id_list:
            result_path = os.path.join(output_dir, 'test_result_%d.pkl' % model_id)
            if os.path.exists(result_path):
                # Already evaluated in an earlier run.
                continue

            model_path = os.path.join(model_dir, 'model-%d' % model_id)
            model.saver.restore(session, model_path)
            total_result = model.test()

            # NOTE(review): presumably rewinds the dataset's test cursor so the
            # next checkpoint is evaluated from the first test sample again —
            # confirm against Dataset's implementation.
            my_data._index_in_test = 0
            my_data.test_tag = True

            # pickle emits bytes, so the file must be opened in binary mode
            # (text mode 'w' raises TypeError on Python 3). Write to a temp
            # file first and rename it into place, so a failed dump never
            # leaves a truncated result that the exists() check above would
            # then skip on every later run.
            tmp_path = result_path + '.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(total_result, f)
            os.rename(tmp_path, result_path)
            print('Save the test result !!! ... %s' % result_path)
评论列表
文章目录