eval.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:EKLAVYA 作者: shensq04 项目源码 文件源码
def testing(config_info):
    """Evaluate every saved checkpoint in ``model_dir`` and pickle the results.

    For each model id returned by ``get_model_id_list(model_dir)`` (processed
    in ascending order) the checkpoint ``model-<id>`` is restored into a single
    shared session, ``model.test()`` is run, and its return value is pickled to
    ``<output_dir>/test_result_<id>.pkl``. Checkpoints whose result file
    already exists are skipped, so an interrupted run can be resumed.

    Args:
        config_info: Mapping providing the keys ``data_folder``, ``func_path``,
            ``embed_path``, ``tag``, ``data_tag``, ``process_num``,
            ``embed_dim``, ``max_length``, ``num_classes``, ``model_dir`` and
            ``output_dir``. The numeric entries are converted with ``int()``.
            A ``data_tag`` of ``'callee'`` selects ``dataset.Dataset``; any
            other value selects ``dataset_caller.Dataset``.
    """
    data_folder = config_info['data_folder']
    func_path = config_info['func_path']
    embed_path = config_info['embed_path']
    tag = config_info['tag']
    data_tag = config_info['data_tag']
    process_num = int(config_info['process_num'])
    embed_dim = int(config_info['embed_dim'])
    max_length = int(config_info['max_length'])
    num_classes = int(config_info['num_classes'])
    model_dir = config_info['model_dir']
    output_dir = config_info['output_dir']

    # Create the output folder. makedirs (rather than mkdir) also creates any
    # missing parent directories of a nested output path.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    print('Created all folders!')

    # Load the dataset variant that matches the requested data tag.
    if data_tag == 'callee':
        dataset_cls = dataset.Dataset
    else:  # caller
        dataset_cls = dataset_caller.Dataset
    my_data = dataset_cls(data_folder, func_path, embed_path, process_num,
                          embed_dim, max_length, num_classes, tag)
    print('Created the dataset!')

    # Collect the checkpoint ids, evaluated in ascending order.
    model_id_list = sorted(get_model_id_list(model_dir))

    with tf.Graph().as_default(), tf.Session() as session:
        # Build the placeholders and the model once; only the weights change
        # between checkpoints.
        data_pl, label_pl, length_pl, keep_prob_pl = placeholder_inputs(
            num_classes, max_length, embed_dim)
        model = Model(session, my_data, config_info,
                      data_pl, label_pl, length_pl, keep_prob_pl)
        print('Created the model!')

        for model_id in model_id_list:
            result_path = os.path.join(output_dir, 'test_result_%d.pkl' % model_id)
            if os.path.exists(result_path):
                # This checkpoint was already evaluated; do not redo it.
                continue

            model_path = os.path.join(model_dir, 'model-%d' % model_id)
            model.saver.restore(session, model_path)

            total_result = model.test()
            # Reset the dataset's test state before the next checkpoint.
            # NOTE(review): presumably model.test() advances/clears these
            # fields while iterating the test set -- confirm in the dataset
            # classes.
            my_data._index_in_test = 0
            my_data.test_tag = True

            # pickle emits bytes, so the file must be opened in binary mode;
            # text mode raises TypeError on Python 3.
            with open(result_path, 'wb') as f:
                pickle.dump(total_result, f)
            print('Save the test result !!! ... %s' % result_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号