attention_sum_reader.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:attention-sum-reader 作者: cairoHy 项目源码 文件源码
def train(self, train_data, valid_data, batch_size, epochs, opt_name, lr, grad_clip):
        """
        Compile the model and fit it on ``train_data``, validating on ``valid_data``.

        Weights are written to ``self.weight_path`` at the end of every epoch,
        and training stops early once ``val_loss`` has not improved for 3 epochs.

        :param train_data: raw training set; turned into model inputs by
            ``self.preprocess_input_sequences``.
        :param valid_data: raw validation set, preprocessed the same way.
        :param batch_size: mini-batch size passed to ``Model.fit``.
        :param epochs: maximum number of training epochs.
        :param opt_name: optimizer name, ``"SGD"`` or ``"ADAM"``.
        :param lr: learning rate for the chosen optimizer.
        :param grad_clip: gradient clipping value (passed as ``clipvalue``).
        :return: the History object returned by ``Model.fit`` (previously discarded).
        :raises NotImplementedError: if ``opt_name`` is not supported.
        """
        # Build the optimizer first so an unsupported name fails fast, before
        # the (potentially slow) preprocessing of both datasets.
        if opt_name == "SGD":
            optimizer = SGD(lr=lr, decay=1e-6, clipvalue=grad_clip)
        elif opt_name == "ADAM":
            optimizer = Adam(lr=lr, clipvalue=grad_clip)
        else:
            raise NotImplementedError("Other Optimizer Not Implemented.-_-||")

        def save_weight_on_epoch_end(epoch, e_logs):
            # Keras >= 2.3 names the validation metric "val_accuracy"; older
            # releases use "val_acc". Accept either; KeyError if neither exists.
            val_acc = e_logs["val_acc"] if "val_acc" in e_logs else e_logs["val_accuracy"]
            # NOTE(review): the timestamp precedes the epoch despite the
            # "weight-epoch" label; kept unchanged in case load_weight parses
            # these names.
            filename = "{}weight-epoch{}-{}-{}.h5".format(self.weight_path,
                                                          time.strftime("%Y-%m-%d-(%H-%M)", time.localtime()),
                                                          epoch,
                                                          val_acc)
            self.model.save_weights(filepath=filename)

        checkpointer = LambdaCallback(on_epoch_end=save_weight_on_epoch_end)
        earlystopping = EarlyStopping(monitor="val_loss", patience=3, verbose=1)

        # Convert raw data into the model's input arrays.
        questions_ok, documents_ok, context_mask, candidates_ok, y_true = self.preprocess_input_sequences(train_data)
        v_questions, v_documents, v_context_mask, v_candidates, v_y_true = self.preprocess_input_sequences(valid_data)

        self.model.compile(optimizer=optimizer,
                           loss="categorical_crossentropy",
                           metrics=["accuracy"])

        # Load previously saved weights via load_weight (presumably resumes
        # from an earlier run if weights exist -- TODO confirm).
        self.load_weight()

        # Keys must match the names of the model's Input layers.
        data = {"q_input": questions_ok,
                "d_input": documents_ok,
                "context_mask": context_mask,
                "candidates_bi": candidates_ok}
        v_data = {"q_input": v_questions,
                  "d_input": v_documents,
                  "context_mask": v_context_mask,
                  "candidates_bi": v_candidates}
        return self.model.fit(x=data,
                              y=y_true,
                              batch_size=batch_size,
                              epochs=epochs,
                              validation_data=(v_data, v_y_true),
                              callbacks=[checkpointer, earlystopping])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号