def train_model(opt, logger):
    """Train (or load) a model, then predict on both test sets and save the results.

    When ``opt.skip_train`` is false, the model is trained for
    ``opt.max_epochs`` rounds of one Keras epoch each. The model is built on
    the first round and reloaded via ``load_model_local`` on later rounds,
    and is saved via ``save_model_local`` after every round. Each round
    trains on all of ``training`` plus a freshly drawn 20% sample of
    ``training_snli``. When ``opt.skip_train`` is true, a previously saved
    model is loaded instead.

    Args:
        opt: Options object. This function reads ``seed``, ``skip_train``,
            ``max_epochs``, ``log_dir``, ``model_name``, ``save_dir`` and
            ``batch_size``; ``load_data`` may return an updated copy.
        logger: Logger used for progress messages (only ``info`` is called).

    Returns:
        None. Predictions are written out by ``save_preds_matched_to_csv``
        and ``save_preds_mismatched_to_csv``.
    """
    logger.info('---START---')
    # initialize for reproduce
    np.random.seed(opt.seed)

    # load data
    logger.info('---LOAD DATA---')
    opt, training, training_snli, validation, test_matched, test_mismatched = load_data(opt)

    if not opt.skip_train:
        logger.info('---TRAIN MODEL---')
        for train_counter in range(opt.max_epochs):
            # Build once, then resume from the copy saved at the end of the
            # previous round.
            if train_counter == 0:
                model = build_model(opt)
            else:
                model = load_model_local(opt)

            # Reseed with the round index so each round draws a different,
            # but repeatable, subset of the SNLI data.
            np.random.seed(train_counter)
            lens = len(training_snli[-1])
            perm = np.random.permutation(lens)
            idx = perm[:int(lens * 0.2)]

            # Append the sampled SNLI rows to each of the three training
            # arrays; the last array is passed to fit() as the target.
            train_data = [
                np.concatenate((training[i], training_snli[i][idx]))
                for i in range(3)
            ]

            csv_logger = CSVLogger('{}{}.csv'.format(opt.log_dir, opt.model_name), append=True)
            # "{val_acc:.2f}" is left unformatted on purpose: it is handed to
            # ModelCheckpoint as part of the file path template.
            cp_filepath = opt.save_dir + "cp-" + opt.model_name + "-" + str(train_counter) + "-{val_acc:.2f}.h5"
            cp = ModelCheckpoint(cp_filepath, monitor='val_acc', save_best_only=True, save_weights_only=True)
            callbacks = [cp, csv_logger]

            model.fit(
                train_data[:-1],
                train_data[-1],
                batch_size=opt.batch_size,
                epochs=1,
                validation_data=(validation[:-1], validation[-1]),
                callbacks=callbacks,
            )
            save_model_local(opt, model)
    else:
        logger.info('---LOAD MODEL---')
        model = load_model_local(opt)

    # predict
    logger.info('---TEST MODEL---')
    preds_matched = model.predict(test_matched[:-1], batch_size=128, verbose=1)
    preds_mismatched = model.predict(test_mismatched[:-1], batch_size=128, verbose=1)

    # Each prediction set is saved with the last element of its own test set.
    # The original passed test_mismatched[-1] for the matched predictions.
    save_preds_matched_to_csv(preds_matched, test_matched[-1], opt)
    save_preds_mismatched_to_csv(preds_mismatched, test_mismatched[-1], opt)
评论列表
文章目录