phoneme_rnn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:rupo 作者: IlyaGusev 项目源码 文件源码
def train(self, dir_name: str, enable_checkpoints: bool = False) -> None:
        """
        Train the network.

        Loads and prepares the dictionary data, holds out 10% of it for
        validation (fixed seed), fits the model with early stopping, measures
        validation accuracy and WER, runs one more epoch over all examples and
        saves the model into ``dir_name`` under a name that encodes the
        language, RNN type, units, dropout, accuracy and WER.

        :param dir_name: directory into which the model (and checkpoints, if
            enabled) is saved; created if it does not exist. An empty string
            means the current working directory.
        :param enable_checkpoints: whether to save a checkpoint after every epoch.
        """
        # Create the output directory before the (long) training run: both
        # ModelCheckpoint and model.save write into it and would otherwise
        # fail only at the very end. "" means "current directory" and must be
        # left alone, since os.makedirs("") raises.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        # Data preparation.
        x, y = self.__load_dict()
        x, y = self.__prepare_data(x, y)
        # Split into train/validation parts (10% held out, reproducible seed).
        x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.1, random_state=42)
        # Build the callback list.
        # NOTE(review): early stopping monitors 'val_acc' while checkpoints monitor
        # 'val_loss'; newer Keras versions report the metric as 'val_accuracy', in
        # which case EarlyStopping only warns and never stops — confirm against the
        # installed Keras version.
        callbacks = [EarlyStopping(monitor='val_acc', patience=3)]  # type: List[Callback]
        if enable_checkpoints:
            checkpoint_name = os.path.join(dir_name, "{epoch:02d}-{val_loss:.2f}.hdf5")
            callbacks.append(ModelCheckpoint(checkpoint_name, monitor='val_loss'))
        # At most 200 epochs; EarlyStopping is expected to end training earlier.
        self.model.fit(x_train, y_train, verbose=1, epochs=200, validation_data=(x_val, y_val),
                       callbacks=callbacks, batch_size=self.batch_size)
        # Accuracy on the validation part. Index 1 of evaluate() assumes the model
        # was compiled with accuracy as its first metric — TODO confirm.
        accuracy = self.model.evaluate(x_val, y_val)[1]
        # WER over the whole dataset (training and validation examples together).
        wer = self.__evaluate_wer(x, y)[0]
        # One more epoch over all examples, so the held-out part is learned too.
        # NOTE: the accuracy/WER embedded in the filename below were measured
        # before this epoch, so they describe the model prior to it.
        self.model.fit(x, y, verbose=1, epochs=1, batch_size=self.batch_size)
        # Save the model.
        filename = "stress_{language}_{rnn}{units}_dropout{dropout}_acc{acc}_wer{wer}.h5"
        filename = filename.format(language=self.language, rnn=self.rnn.__name__,
                                   units=self.units, dropout=self.dropout, acc=int(accuracy * 100),
                                   wer=int(wer * 100))
        self.model.save(os.path.join(dir_name, filename))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号