run.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目: QDREN 作者: andreamad8 项目源码 文件源码
def main():
    """Grid-search the windowed CNN's hyper-parameters and pickle every run.

    Every combination in ``param_grid`` is visited in random order. For each
    one a ``Dataset`` is built, model parameters are derived with
    ``get_parameters``, the model is trained with ``train``, and the largest
    accuracy value in ``train``'s fourth result element is recorded. The
    complete ``train`` result is pickled to
    ``checkpoints/CNN_WIND/<params><acc>.PIK``.

    NOTE(review): the checkpoint directory is assumed to exist already;
    ``open`` raises if it does not.
    """
    embedding_size = 100
    epoch = 300
    best_accuracy = 0.0
    sent_numb, sent_len = None, None
    # str(params) -> best accuracy reported for that configuration.
    # NOTE(review): neither this dict nor best_accuracy is reported or
    # returned yet; both are only accumulated.
    grid_results = {}

    param_grid = {'nb': [5],
                  'lr': [0.01, 0.001, 0.0001],
                  'tr': [[1, 1, 0, 0]],
                  'L2': [0.001, 0.0001],
                  'bz': [64],
                  'dr': [0.5],
                  'mw': [150],
                  'w': [3, 4, 5],
                  'op': ['Adam']
                  }
    grid = list(ParameterGrid(param_grid))
    # Visit the configurations in random order (shuffled in place).
    np.random.shuffle(grid)
    for params in grid:

        data = Dataset(train_size=10000, dev_size=None, test_size=None, sent_len=sent_len,
                       sent_numb=sent_numb, embedding_size=embedding_size,
                       max_windows=params['mw'], win=params['w'])

        # Windowed input: the "sentence length" slot receives the window
        # width 2*w + 1 and the "sentence count" slot receives the maximum
        # number of windows `mw`.
        # Sentence-level variant (disabled):
        #   par = get_parameters(data, epoch, sent_len, sent_numb, embedding_size)
        par = get_parameters(data, epoch, (params['w'] * 2) + 1, params['mw'],
                             embedding_size, params)

        t = train(epoch, params['bz'], data, par, dr=params['dr'], _test=False)

        # t[3] is presumably a mapping of per-epoch accuracies (TODO confirm
        # against train()); keep the best value it contains.
        acc = max(t[3].values())

        best_accuracy = max(best_accuracy, acc)

        grid_results[str(params)] = acc

        f_save = 'checkpoints/CNN_WIND/{}.PIK'.format(str(params) + str(acc))
        # pickle writes bytes, so the file must be opened in binary mode.
        with open(f_save, 'wb') as f:
            pickle.dump(t, f)
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号