def train(self, train_data, valid_data, batch_size, epochs, opt_name, lr, grad_clip):
    """Compile the model and fit it on the training data.

    Weights are checkpointed after every epoch, and training stops early
    once validation loss has not improved for 3 epochs.

    Args:
        train_data: raw training samples, in the format consumed by
            ``self.preprocess_input_sequences``.
        valid_data: raw validation samples, same format as ``train_data``.
        batch_size: mini-batch size passed to ``model.fit``.
        epochs: maximum number of training epochs.
        opt_name: optimizer selector, either ``"SGD"`` or ``"ADAM"``.
        lr: learning rate for the chosen optimizer.
        grad_clip: per-value gradient clipping threshold (``clipvalue``).

    Returns:
        The Keras ``History`` object produced by ``model.fit``.

    Raises:
        NotImplementedError: if ``opt_name`` is neither ``"SGD"`` nor ``"ADAM"``.
    """
    def save_weight_on_epoch_end(epoch, e_logs):
        # FIX: the original template "{}weight-epoch{}-{}-{}.h5" put the
        # timestamp into the "epoch" slot; arguments now match their labels:
        # <path>weight-<timestamp>-epoch<epoch>-<val_acc>.h5
        filename = "{}weight-{}-epoch{}-{}.h5".format(
            self.weight_path,
            time.strftime("%Y-%m-%d-(%H-%M)", time.localtime()),
            epoch,
            e_logs['val_acc'])
        self.model.save_weights(filepath=filename)

    checkpointer = LambdaCallback(on_epoch_end=save_weight_on_epoch_end)
    # tensorboard = TensorBoard(log_dir="./logs", histogram_freq=1, write_images=True)
    earlystopping = EarlyStopping(monitor="val_loss", patience=3, verbose=1)

    # Turn raw samples into the padded input arrays and targets the model expects.
    questions_ok, documents_ok, context_mask, candidates_ok, y_true = \
        self.preprocess_input_sequences(train_data)
    v_questions, v_documents, v_context_mask, v_candidates, v_y_true = \
        self.preprocess_input_sequences(valid_data)

    if opt_name == "SGD":
        optimizer = SGD(lr=lr, decay=1e-6, clipvalue=grad_clip)
    elif opt_name == "ADAM":
        optimizer = Adam(lr=lr, clipvalue=grad_clip)
    else:
        raise NotImplementedError("Other Optimizer Not Implemented.-_-||")
    self.model.compile(optimizer=optimizer,
                       loss="categorical_crossentropy",
                       metrics=["accuracy"])

    # Restore previously saved weights so training can resume.
    # NOTE(review): behavior depends on self.load_weight(), defined elsewhere.
    self.load_weight()

    data = {"q_input": questions_ok,
            "d_input": documents_ok,
            "context_mask": context_mask,
            "candidates_bi": candidates_ok}
    v_data = {"q_input": v_questions,
              "d_input": v_documents,
              "context_mask": v_context_mask,
              "candidates_bi": v_candidates}
    logs = self.model.fit(x=data,
                          y=y_true,
                          batch_size=batch_size,
                          epochs=epochs,
                          validation_data=(v_data, v_y_true),
                          callbacks=[checkpointer, earlystopping])
    # FIX: the History object was assigned but discarded; return it so callers
    # can inspect per-epoch metrics.
    return logs
# Source: attention_sum_reader.py (scraped code-listing page; view/like/comment
# counters and page navigation removed — they were site metadata, not code).