def _create_training_dirs(output_dir):
    """Create the model and log directories under ``output_dir``.

    ``output_dir/model`` is created if missing and its contents are kept.
    ``output_dir/log`` is deleted (if present) and recreated empty, so each
    run starts with a fresh log directory.

    Returns:
        A ``(model_basedir, log_basedir)`` tuple of paths.
    """
    model_basedir = os.path.join(output_dir, 'model')
    log_basedir = os.path.join(output_dir, 'log')
    # MakeDirs also creates missing parents (including output_dir itself) and
    # does not fail when the directory already exists, unlike os.mkdir.
    tf.gfile.MakeDirs(model_basedir)
    if tf.gfile.Exists(log_basedir):
        tf.gfile.DeleteRecursively(log_basedir)
    tf.gfile.MakeDirs(log_basedir)
    return model_basedir, log_basedir


def training(config_info):
    """Build the dataset and model from ``config_info`` and train until done.

    Training runs until the dataset reports ``epoch_num`` completed epochs.
    A checkpoint is written every ``save_batchs`` training steps. A final
    checkpoint is written if the last step was not already saved.

    Args:
        config_info: dict-like configuration. Keys read here: 'data_folder',
            'func_path', 'embed_path', 'tag', 'data_tag', 'output_dir', plus
            'process_num', 'embed_dim', 'max_length', 'num_classes',
            'epoch_num', 'save_batchs' (each converted with ``int``).
            The key 'log_path' is written back with the log directory path
            before the model is constructed.

    Raises:
        ValueError: if 'save_batchs' is not a positive integer.
    """
    data_folder = config_info['data_folder']
    func_path = config_info['func_path']
    embed_path = config_info['embed_path']
    tag = config_info['tag']
    data_tag = config_info['data_tag']
    process_num = int(config_info['process_num'])
    embed_dim = int(config_info['embed_dim'])
    max_length = int(config_info['max_length'])
    num_classes = int(config_info['num_classes'])
    epoch_num = int(config_info['epoch_num'])
    save_batch_num = int(config_info['save_batchs'])
    output_dir = config_info['output_dir']

    # Fail fast: a zero interval would otherwise raise ZeroDivisionError in the
    # training loop, after the (potentially expensive) dataset/model setup.
    if save_batch_num <= 0:
        raise ValueError(
            "config_info['save_batchs'] must be a positive integer, got %r"
            % config_info['save_batchs'])

    # Create model & log folders.
    model_basedir, log_basedir = _create_training_dirs(output_dir)
    # Written back into the shared config before Model() is built below.
    config_info['log_path'] = log_basedir
    print('Created all folders!')

    # Load dataset: 'callee' uses the `dataset` module; anything else is
    # treated as the caller variant.
    dataset_module = dataset if data_tag == 'callee' else dataset_caller
    my_data = dataset_module.Dataset(data_folder, func_path, embed_path,
                                     process_num, embed_dim, max_length,
                                     num_classes, tag)
    print('Created the dataset!')

    checkpoint_prefix = os.path.join(model_basedir, 'model')
    with tf.Graph().as_default(), tf.Session() as session:
        # Generate placeholders.
        data_pl, label_pl, length_pl, keep_prob_pl = placeholder_inputs(
            num_classes, max_length, embed_dim)
        # Generate model.
        model = Model(session, my_data, config_info,
                      data_pl, label_pl, length_pl, keep_prob_pl)
        print('Created the model!')

        trained = False
        last_saved_step = None
        try:
            # NOTE(review): relies on the dataset's private `_complete_epochs`
            # counter, presumably advanced as batches are consumed by
            # model.train() — confirm against the Dataset implementation.
            while my_data._complete_epochs < epoch_num:
                model.train()
                trained = True
                if model.run_count % save_batch_num == 0:
                    model.saver.save(session, checkpoint_prefix,
                                     global_step=model.run_count)
                    last_saved_step = model.run_count
                    print('Saved the model ... %d' % model.run_count)
            # Persist the final weights if the last step fell between save
            # intervals; otherwise the end of training would be lost.
            if trained and model.run_count != last_saved_step:
                model.saver.save(session, checkpoint_prefix,
                                 global_step=model.run_count)
                print('Saved the model ... %d' % model.run_count)
        finally:
            # Always flush/close summary writers, even if training raised.
            model.train_writer.close()
            model.test_writer.close()
评论列表
文章目录