def main(task_num, sample_size=''):
    """Run the hyper-parameter grid search for one task and pickle each result.

    For every parameter combination in the grid, the dataset is loaded,
    model parameters are built with ``get_parameters``, the model is trained
    with ``train``, and the object returned by ``train`` is pickled to
    ``data/ris/task_<task_num>/<params><acc>.PIK``.

    Args:
        task_num: Task identifier. It is formatted into the output directory
            name and passed to ``Dataset`` as ``int(task_num)``.
        sample_size: Suffix appended to the ``en-valid`` dataset directory
            name (empty string by default).
    """
    embedding_size = 100
    epoch = 300
    best_accuracy = 0.0
    grind_ris = {}  # str(params) -> best accuracy reached with those params

    out_dir = 'data/ris/task_{}'.format(task_num)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    param_grid = {
        'nb': [20],
        'lr': [0.001],
        'tr': [[0, 0, 0, 0]],
        'L2': [0.001],  # [0.0,0.1,0.01,0.001,0.0001]
        'bz': [32],
        'dr': [0.5],
    }
    grid = list(ParameterGrid(param_grid))
    # Visit the combinations in random order.
    np.random.shuffle(grid)

    for params in grid:
        data = Dataset(
            'data/tasks_1-20_v1-2/en-valid{}/'.format(sample_size),
            int(task_num),
        )
        # for sentence: sizes are read from data._data['sent_len'] / ['sent_numb']
        par = get_parameters(
            data,
            epoch,
            data._data['sent_len'],
            data._data['sent_numb'],
            embedding_size,
            params,
        )
        t = train(epoch, params['bz'], data, par, dr=params['dr'], _test=True)

        # t[5] is a mapping; its largest value is taken as this run's accuracy.
        acc = max(t[5].values())
        if acc > best_accuracy:
            best_accuracy = acc

        # NOTE(review): the original indentation was lost; the result is
        # assumed to be recorded and saved for every combination, not only
        # for improvements -- confirm against the upstream file.
        grind_ris[str(params)] = acc
        f_save = 'data/ris/task_{}/{}.PIK'.format(task_num, str(params) + str(acc))
        # pickle output is binary: the file must be opened in 'wb' mode
        # (text mode fails on Python 3 and can corrupt data on Windows).
        with open(f_save, 'wb') as f:
            pickle.dump(t, f)
# batch_size = 32
# epoch = 200
# if not os.path.exists('data/ris/task_{}'.format(task_num)):
# os.makedirs('data/ris/task_{}'.format(task_num))
# data = Dataset('data/tasks_1-20_v1-2/en-valid{}/'.format(sample_size),int(task_num))
# Comment list
# Table of contents