main.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:ethnicity-tensorflow 作者: jhyuklee 项目源码 文件源码
def main(_):
    """Run a random hyper-parameter search over the ethnicity RNN model.

    Derives a model name from the ``ensemble`` / ``ngram`` / ``embed`` flags,
    appends it to ``checkpoint_dir``, loads the data once via ``get_data``,
    and then runs ``valid_iteration`` trials.  Each trial samples a parameter
    set, builds character-embedding initialisers, trains/evaluates an ``RNN``
    through ``experiment`` and appends the resulting top-1 / top-5 scores and
    epoch to the file at ``valid_result_path``.

    Args:
        _: Ignored positional argument (presumably the argv list passed by the
            app runner -- TODO confirm against the caller).

    Raises:
        ValueError: If ``ensemble`` is false and ``ngram`` is not 1, 2 or 3.
    """
    # Save default params and set scope.
    # NOTE: this is an alias of FLAGS.__flags, not a copy, so the
    # 'model_name' / 'checkpoint_dir' updates below are visible via FLAGS too.
    saved_params = FLAGS.__flags
    ngram_names = {1: 'unigram', 2: 'bigram', 3: 'trigram'}
    if saved_params['ensemble']:
        model_name = 'ensemble'
    elif saved_params['ngram'] in ngram_names:
        model_name = ngram_names[saved_params['ngram']]
    else:
        # Previously `assert True, ...`, which can never fire; the unsupported
        # value then surfaced as an UnboundLocalError on `model_name` below.
        raise ValueError('Not supported ngram %d' % saved_params['ngram'])
    model_name += '_embedding' if saved_params['embed'] else '_no_embedding'
    saved_params['model_name'] = model_name
    saved_params['checkpoint_dir'] += model_name
    pprint.PrettyPrinter().pprint(saved_params)
    saved_dataset = get_data(saved_params)

    # Append mode: results from earlier runs in the same file are kept.
    # The context manager guarantees the file is closed even if a trial fails.
    with open(saved_params['valid_result_path'], 'a') as validation_writer:
        validation_writer.write(model_name + "\n")
        validation_writer.write("[dim_hidden, dim_rnn_cell, learning_rate, lstm_dropout, lstm_layer, hidden_dropout, dim_embed]\n")
        validation_writer.write("combination\ttop1\ttop5\tepoch\n")

        # Run the model
        for _ in range(saved_params['valid_iteration']):
            # Sample a parameter set from a fresh copy of the defaults so that
            # trials do not influence each other through shared dict state.
            params, combination = sample_parameters(saved_params.copy())
            # Shallow copy of the top-level dataset sequence for this trial.
            dataset = saved_dataset[:]

            # Initialize embeddings (one per n-gram order). The meaning of the
            # dataset indices [0][k] / [3][k] is defined by get_data --
            # presumably per-order inputs and vocabularies; verify there.
            uni_init = get_char2vec(dataset[0][0][:], params['dim_embed_unigram'], dataset[3][0])
            bi_init = get_char2vec(dataset[0][1][:], params['dim_embed_bigram'], dataset[3][4])
            tri_init = get_char2vec(dataset[0][2][:], params['dim_embed_trigram'], dataset[3][5])

            print(model_name, 'Parameter sets: ', end='')
            pprint.PrettyPrinter().pprint(combination)

            rnn_model = RNN(params, [uni_init, bi_init, tri_init])
            top1, top5, ep = experiment(rnn_model, dataset, params)

            validation_writer.write(str(combination) + '\t')
            validation_writer.write(str(top1) + '\t' + str(top5) + '\tEp:' + str(ep) + '\n')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号