run_final.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:QDREN 作者: andreamad8 项目源码 文件源码
def main(task_num, sample_size=''):
    """Grid-search training hyper-parameters for one task and pickle every run.

    Every combination in ``param_grid`` is visited in random order (the order
    comes from ``np.random``'s global state). For each combination the
    function:

    1. loads ``Dataset('data/tasks_1-20_v1-2/en-valid<sample_size>/',
       int(task_num))``;
    2. trains via ``train(..., _test=True)``;
    3. takes the largest value in ``t[5]`` as the run's accuracy;
    4. pickles the whole result tuple ``t`` to
       ``data/ris/task_<task_num>/<str(params)><acc>.PIK``.

    Args:
        task_num: Task identifier. It is formatted into the output directory
            name and passed to ``Dataset`` as ``int(task_num)``.
        sample_size: Suffix appended to the ``en-valid`` data directory name.
            The empty string selects ``en-valid/``.

    Returns:
        dict mapping ``str(params)`` to the accuracy obtained with those
        parameters. The original implementation returned ``None``; callers
        that ignore the return value are unaffected.
    """
    embedding_size = 100
    epoch = 300
    grind_ris = {}

    out_dir = 'data/ris/task_{}'.format(task_num)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    param_grid = {'nb': [20],
                  'lr': [0.001],
                  'tr': [[0, 0, 0, 0]],
                  # Other L2 values listed by the author: [0.0, 0.1, 0.01, 0.001, 0.0001]
                  'L2': [0.001],
                  'bz': [32],
                  'dr': [0.5],
                  }
    grid = list(ParameterGrid(param_grid))
    np.random.shuffle(grid)  # in-place; randomises the visiting order

    for params in grid:
        # NOTE(review): the dataset does not depend on ``params`` but is
        # reloaded on every iteration, presumably so each run starts from a
        # fresh Dataset object. Confirm that ``train`` does not mutate it
        # before hoisting this out of the loop.
        data = Dataset('data/tasks_1-20_v1-2/en-valid{}/'.format(sample_size), int(task_num))

        par = get_parameters(data, epoch, data._data['sent_len'], data._data['sent_numb'],
                             embedding_size, params)
        t = train(epoch, params['bz'], data, par, dr=params['dr'], _test=True)

        # t[5] is a dict of scores (presumably one per evaluation step; TODO
        # confirm against ``train``). The best one is this run's accuracy.
        acc = max(t[5].values())
        grind_ris[str(params)] = acc

        f_save = '{}/{}.PIK'.format(out_dir, str(params) + str(acc))
        # pickle emits bytes: binary mode is required on Python 3 and avoids
        # newline translation corrupting the stream on Windows.
        with open(f_save, 'wb') as f:
            pickle.dump(t, f)

    return grind_ris
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号