def train(self, dir_name: str, enable_checkpoints: bool = False) -> None:
    """
    Train the model, evaluate it and save it into ``dir_name``.

    The data returned by ``__load_dict`` is passed through ``__prepare_data``.
    10% of it is held out for validation, and the model is fit for up to 200
    epochs with early stopping. Validation accuracy and WER are then measured.
    The model is trained for one more epoch on the full dataset and finally
    saved as an ``.h5`` file whose name encodes the configuration and the
    measured metrics.

    :param dir_name: directory the final model (and, if enabled, the per-epoch
        checkpoints) is written to; created if it does not exist. An empty
        string means the current working directory.
    :param enable_checkpoints: whether to add a ``ModelCheckpoint`` callback
        to the main training run.
    """
    # Create the output directory before training. Otherwise a missing
    # directory is only discovered at the first checkpoint or at the final
    # model.save(), after the training time has already been spent.
    # An empty dir_name means "current directory" (os.path.join("", f) == f),
    # and os.makedirs("") would raise, so skip it in that case.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    # Load and prepare the data.
    x, y = self.__load_dict()
    x, y = self.__prepare_data(x, y)
    # Hold out 10% of the samples for validation; the fixed seed makes the split reproducible.
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.1, random_state=42)

    # Stop training once the monitored validation metric has not improved for 3 epochs.
    # NOTE(review): 'val_acc' must match the metric name the model was compiled with
    # (newer Keras versions report 'val_accuracy') — confirm.
    callbacks = [EarlyStopping(monitor='val_acc', patience=3)]  # type: List[Callback]
    if enable_checkpoints:
        # Checkpoint file name is built from the epoch number and the validation loss.
        checkpoint_name = os.path.join(dir_name, "{epoch:02d}-{val_loss:.2f}.hdf5")
        callbacks.append(ModelCheckpoint(checkpoint_name, monitor='val_loss'))
    self.model.fit(x_train, y_train, verbose=1, epochs=200, validation_data=(x_val, y_val),
                   callbacks=callbacks, batch_size=self.batch_size)

    # Element 1 of evaluate() is the first compiled metric after the loss —
    # presumably accuracy; confirm against the model's compile() call.
    accuracy = self.model.evaluate(x_val, y_val)[1]
    # WER is computed on the whole dataset (training and validation samples alike).
    wer = self.__evaluate_wer(x, y)[0]
    # One extra epoch on all data, including the validation split.
    # NOTE: this runs after the metrics above were measured, so the saved weights
    # are not exactly the ones the acc/wer values in the file name describe.
    self.model.fit(x, y, verbose=1, epochs=1, batch_size=self.batch_size)

    # Save the model; the file name records the configuration and the metrics (as integer percents).
    filename = "stress_{language}_{rnn}{units}_dropout{dropout}_acc{acc}_wer{wer}.h5"
    filename = filename.format(language=self.language, rnn=self.rnn.__name__,
                               units=self.units, dropout=self.dropout, acc=int(accuracy * 100),
                               wer=int(wer * 100))
    self.model.save(os.path.join(dir_name, filename))
评论列表
文章目录