model.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:CIAN 作者: yanghanxy 项目源码 文件源码
def train_model(opt, logger):
    """Train (or load) the model, then write predictions for both test sets.

    When ``opt.skip_train`` is false, the model is trained for
    ``opt.max_epochs`` rounds of one epoch each.  Each round trains on all of
    ``training`` plus a freshly sampled 20% subset of ``training_snli``, and
    the model is saved with ``save_model_local`` at the end of the round and
    reloaded with ``load_model_local`` at the start of the next one.  When
    ``opt.skip_train`` is true, a model is loaded with ``load_model_local``
    instead.

    Finally the model predicts on the matched and mismatched test sets and the
    predictions are passed to ``save_preds_matched_to_csv`` and
    ``save_preds_mismatched_to_csv`` respectively.

    Args:
        opt: Options object.  This function reads ``seed``, ``skip_train``,
            ``max_epochs``, ``log_dir``, ``model_name``, ``save_dir`` and
            ``batch_size``; ``load_data`` may return an updated ``opt``.
        logger: Logger used for progress messages (only ``info`` is called).

    Raises:
        ValueError: If training is requested but ``opt.max_epochs`` is less
            than 1, because no model would exist to run predictions with.
    """
    logger.info('---START---')
    # Seed NumPy's global RNG so the run is reproducible.
    np.random.seed(opt.seed)

    # load data
    logger.info('---LOAD DATA---')
    # Each data set is indexed positionally below: every element but the last
    # is fed to the model as an input and the last element is the target.
    # NOTE(review): the fancy indexing on training_snli assumes its elements
    # are NumPy arrays - confirm against load_data.
    opt, training, training_snli, validation, test_matched, test_mismatched = load_data(opt)

    if not opt.skip_train:
        logger.info('---TRAIN MODEL---')
        if opt.max_epochs < 1:
            # Without at least one round `model` would never be bound and the
            # prediction step below would fail with an opaque NameError.
            raise ValueError(
                'opt.max_epochs must be >= 1 when training, got {}'.format(opt.max_epochs))
        for train_counter in range(opt.max_epochs):
            # Build a fresh model on the first round; afterwards continue from
            # whatever load_model_local returns (presumably the model saved at
            # the end of the previous round - verify against the helper).
            if train_counter == 0:
                model = build_model(opt)
            else:
                model = load_model_local(opt)

            # Reseed with the round index so each round draws a different but
            # deterministic 20% subset of the SNLI data.
            np.random.seed(train_counter)
            lens = len(training_snli[-1])
            perm = np.random.permutation(lens)
            idx = perm[:int(lens * 0.2)]
            train_data = [np.concatenate((training[0], training_snli[0][idx])),
                          np.concatenate((training[1], training_snli[1][idx])),
                          np.concatenate((training[2], training_snli[2][idx]))]

            # append=True keeps the log rows from earlier rounds in one CSV.
            csv_logger = CSVLogger('{}{}.csv'.format(opt.log_dir, opt.model_name), append=True)
            # The round index and validation accuracy are part of the
            # checkpoint file name, so every round writes its own file.
            cp_filepath = opt.save_dir + "cp-" + opt.model_name + "-" + str(train_counter) + "-{val_acc:.2f}.h5"
            cp = ModelCheckpoint(cp_filepath, monitor='val_acc', save_best_only=True, save_weights_only=True)
            callbacks = [cp, csv_logger]

            # One Keras epoch per round; the outer loop supplies the epochs.
            model.fit(train_data[:-1], train_data[-1],
                      batch_size=opt.batch_size,
                      epochs=1,
                      validation_data=(validation[:-1], validation[-1]),
                      callbacks=callbacks)
            save_model_local(opt, model)
    else:
        logger.info('---LOAD MODEL---')
        model = load_model_local(opt)

    # predict
    logger.info('---TEST MODEL---')
    preds_matched = model.predict(test_matched[:-1], batch_size=128, verbose=1)
    preds_mismatched = model.predict(test_mismatched[:-1], batch_size=128, verbose=1)

    # Pair each prediction array with the last element of its OWN test set
    # (the matched call previously received the mismatched set's element).
    save_preds_matched_to_csv(preds_matched, test_matched[-1], opt)
    save_preds_mismatched_to_csv(preds_mismatched, test_mismatched[-1], opt)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号