def main():
    """Run a randomized grid search over CNN window-model hyperparameters.

    Every combination in ``param_grid`` is visited in random order
    (``np.random.shuffle``). For each one:

    - a ``Dataset`` is built;
    - model parameters are created with ``get_parameters``;
    - the model is trained with ``train``.

    The largest value in ``t[3]`` is taken as the run's score. ``t[3]`` is
    presumably a mapping of evaluation step -> accuracy; TODO confirm
    against ``train``.

    Whenever a run beats the best score seen so far, the full ``train``
    result is pickled to ``checkpoints/CNN_WIND/<params><acc>.PIK``.

    Returns:
        dict: ``str(params) -> accuracy`` for every run that improved on the
        running best. The original built this dict but discarded it.
    """
    import os  # local import: the file's top-level imports are not visible here

    embedding_size = 100
    epoch = 300
    best_accuracy = 0.0
    sent_numb, sent_len = None, None
    grid_results = {}
    param_grid = {
        'nb': [5],
        'lr': [0.01, 0.001, 0.0001],
        'tr': [[1, 1, 0, 0]],
        'L2': [0.001, 0.0001],
        'bz': [64],
        'dr': [0.5],
        'mw': [150],
        'w': [3, 4, 5],
        'op': ['Adam'],
    }

    checkpoint_dir = os.path.join('checkpoints', 'CNN_WIND')
    # open() below fails with FileNotFoundError if the directory is missing.
    # isdir + makedirs (rather than exist_ok=True) keeps Python 2 compatibility.
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    grid = list(ParameterGrid(param_grid))
    np.random.shuffle(grid)
    for params in grid:
        data = Dataset(train_size=10000, dev_size=None, test_size=None,
                       sent_len=sent_len, sent_numb=sent_numb,
                       embedding_size=embedding_size,
                       max_windows=params['mw'], win=params['w'])
        # Window-based variant. The sentence-based variant passed
        # (sent_len, sent_numb) as the 3rd/4th arguments.
        # Here the window length (w * 2 + 1, presumably w tokens on each side
        # of a centre token -- confirm in get_parameters) and the window cap
        # 'mw' take those two slots.
        par = get_parameters(data, epoch, params['w'] * 2 + 1, params['mw'],
                             embedding_size, params)
        t = train(epoch, params['bz'], data, par, dr=params['dr'], _test=False)
        # Score the run by the best value recorded in t[3].
        acc = max(t[3].values())
        if acc > best_accuracy:
            best_accuracy = acc
            grid_results[str(params)] = acc
            f_save = os.path.join(checkpoint_dir,
                                  '{}.PIK'.format(str(params) + str(acc)))
            # pickle writes bytes: the file must be opened in binary mode
            # (text mode raises TypeError on Python 3).
            with open(f_save, 'wb') as f:
                pickle.dump(t, f)
    return grid_results
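# NOTE: the two lines below are scraped web-page residue, not code.
# 评论列表 ("comment list")
# 文章目录 ("table of contents")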