rcnn_merge_lstm.py source code

python

Project: mtl    Author: zhenhongChen
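import datetime
import sys

# build_models, generate_batch_data, evaluate and save_predictions are not
# shown here; they come from elsewhere in the mtl project.


def rcnn_mtl(processed_datasets, index_embedding, params):
    """Train the joint multi-task model and periodically evaluate each task."""
    start = datetime.datetime.now()

    x_trains, y_trains, x_tests, y_tests = processed_datasets
    # mtl_model is trained jointly; single_models (one per task) are used
    # for evaluation and prediction.
    mtl_model, single_models = build_models(params, index_embedding)
    mtl_model.summary()  # Keras summary() prints itself and returns None
    # plot_model(mtl_model, to_file='mtl_model.png', show_shapes=True)

    batch_input = {}   # filled in place by generate_batch_data
    batch_output = {}
    batch_size = params['batch_size']
    iterations = params['iterations']
    sys.stdout.write('\ntotal iterations: {}'.format(iterations))

    itera = 0
    while itera < iterations:
        # Draw one mini-batch and take a single training step on the joint model.
        generate_batch_data(batch_input, batch_output, batch_size, x_trains, y_trains)
        mtl_model.train_on_batch(batch_input, batch_output)

        itera += 1
        # Evaluate every 100 iterations once past iteration 200 (300, 400, ...).
        if itera > 200 and itera % 100 == 0:
            sys.stdout.write('\n\ncurrent iteration: {}'.format(itera))
            # evaluate(single_models, x_trains, y_trains, 'train')
            evaluate(single_models, x_tests, y_tests, 'test')
            sys.stdout.flush()

            # From iteration 500 on, also save test-set predictions.
            if itera >= 500:
                save_predictions(single_models, x_tests, params['prediction_path'])
                # save_models(single_models, params['save_model_path'])

    end = datetime.datetime.now()
    sys.stdout.write('\nused time: {}\n'.format(end - start))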
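generate_batch_data is not shown in this file. What follows is a minimal sketch of what it might do, not the project's implementation. It assumes that x_trains and y_trains are lists of NumPy arrays with one entry per task, and that the model's input and output layers are named input_0, output_0, input_1, and so on (both are assumptions). It fills the two dicts in place with one random mini-batch per task, the name-keyed form that Keras train_on_batch accepts for multi-input, multi-output models.

```python
import numpy as np


def generate_batch_data(batch_input, batch_output, batch_size, x_trains, y_trains, rng=None):
    """Hypothetical sketch: fill the dicts in place with one random batch per task.

    The keys 'input_i' / 'output_i' are assumed layer names; in a real model
    they must match the names of its input and output layers.
    """
    rng = np.random.default_rng() if rng is None else rng
    for i, (x, y) in enumerate(zip(x_trains, y_trains)):
        # Sample row indices without replacement (assumes len(x) >= batch_size),
        # and use the same indices for x and y so examples stay paired.
        idx = rng.choice(len(x), size=batch_size, replace=False)
        batch_input['input_%d' % i] = x[idx]
        batch_output['output_%d' % i] = y[idx]


# Usage with two toy tasks of 10 and 15 examples.
x_trains = [np.arange(20).reshape(10, 2), np.arange(30).reshape(15, 2)]
y_trains = [np.zeros(10), np.ones(15)]
batch_input, batch_output = {}, {}
generate_batch_data(batch_input, batch_output, 4, x_trains, y_trains, np.random.default_rng(0))
print(sorted(batch_input), batch_input['input_0'].shape)  # ['input_0', 'input_1'] (4, 2)
```

Reusing the same two dicts across iterations, as rcnn_mtl does, works because every call overwrites each key with a fresh batch.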
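The evaluation and saving conditions in rcnn_mtl are easy to misread. Because itera is incremented before the check and the test is itera > 200, the first evaluation happens at iteration 300, not 200, and save_predictions then runs at every evaluation from iteration 500 onward. This small sketch replays the same checks without a model to list which iterations trigger each action; eval_schedule and its keyword names are hypothetical, with defaults set to the constants hard-coded in the loop.

```python
def eval_schedule(iterations, eval_after=200, eval_every=100, save_from=500):
    """Hypothetical helper: return (eval_steps, save_steps) for the loop's checks."""
    evals, saves = [], []
    for itera in range(1, iterations + 1):  # itera is 1-based after the increment
        if itera > eval_after and itera % eval_every == 0:
            evals.append(itera)
            if itera >= save_from:
                saves.append(itera)
    return evals, saves


print(eval_schedule(800))  # ([300, 400, 500, 600, 700, 800], [500, 600, 700, 800])
```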