train.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:EKLAVYA 作者: shensq04 项目源码 文件源码
def training(config_info):
    """Build the dataset and model from ``config_info`` and run the training loop.

    Args:
        config_info: dict-like configuration. Keys read here are
            'data_folder', 'func_path', 'embed_path', 'tag', 'data_tag',
            'process_num', 'embed_dim', 'max_length', 'num_classes',
            'epoch_num', 'save_batchs' and 'output_dir'. The numeric entries
            are converted with ``int()``. The key 'log_path' is written into
            this dict (set to ``<output_dir>/log``) before the model is
            constructed, and the whole dict is handed to ``Model``.

    Side effects:
        * Creates ``<output_dir>``, ``<output_dir>/model`` and
          ``<output_dir>/log``; an existing log directory is deleted first.
        * Saves a checkpoint under ``<output_dir>/model`` whenever
          ``model.run_count`` is a multiple of 'save_batchs'.
        * Closes the model's train/test writers when the loop ends, including
          when training raises.
    """
    data_folder = config_info['data_folder']
    func_path = config_info['func_path']
    embed_path = config_info['embed_path']
    tag = config_info['tag']
    data_tag = config_info['data_tag']
    process_num = int(config_info['process_num'])
    embed_dim = int(config_info['embed_dim'])
    max_length = int(config_info['max_length'])
    num_classes = int(config_info['num_classes'])
    epoch_num = int(config_info['epoch_num'])
    save_batch_num = int(config_info['save_batchs'])
    output_dir = config_info['output_dir']

    # Create model & log folders.
    # MakeDirs creates missing parents and succeeds if the directory already
    # exists, so no separate existence check is needed and a nested
    # output_dir works.
    tf.gfile.MakeDirs(output_dir)
    model_basedir = os.path.join(output_dir, 'model')
    tf.gfile.MakeDirs(model_basedir)

    # The log folder is always recreated from scratch.
    log_basedir = os.path.join(output_dir, 'log')
    if tf.gfile.Exists(log_basedir):
        tf.gfile.DeleteRecursively(log_basedir)
    tf.gfile.MakeDirs(log_basedir)
    config_info['log_path'] = log_basedir
    print('Created all folders!')

    # Load dataset: 'callee' selects dataset.Dataset, anything else is
    # treated as caller and uses dataset_caller.Dataset.
    if data_tag == 'callee':
        my_data = dataset.Dataset(data_folder, func_path, embed_path, process_num, embed_dim, max_length, num_classes, tag)
    else:  # caller
        my_data = dataset_caller.Dataset(data_folder, func_path, embed_path, process_num, embed_dim, max_length, num_classes, tag)

    print('Created the dataset!')

    with tf.Graph().as_default(), tf.Session() as session:
        # generate placeholder
        data_pl, label_pl, length_pl, keep_prob_pl = placeholder_inputs(num_classes, max_length, embed_dim)

        # generate model
        model = Model(session, my_data, config_info, data_pl, label_pl, length_pl, keep_prob_pl)
        print('Created the model!')

        checkpoint_prefix = os.path.join(model_basedir, 'model')
        try:
            while my_data._complete_epochs < epoch_num:
                model.train()
                if model.run_count % save_batch_num == 0:
                    model.saver.save(session, checkpoint_prefix, global_step=model.run_count)
                    print('Saved the model ... %d' % model.run_count)
        finally:
            # Close the writers even if training raises.
            model.train_writer.close()
            model.test_writer.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号