cosLSTM.py source code

python

Project: Question-Answering-NNs    Author: nbogdan
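import os

import numpy as np
from keras.callbacks import ModelCheckpoint  # or tensorflow.keras.callbacks

# Method of the model class; VERSION and MAX_SEQUENCE_LENGTH_C are module-level constants in cosLSTM.py.
def train(self, train_data, validation_data, folder):
    # Each split is a (context, question, answer, labels) tuple.
    context_data, question_data, answer_data, y_train = train_data
    context_data_v, question_data_v, answer_data_v, y_val = validation_data
    print("Model Fitting")

    # os.path.join works whether or not `folder` ends with a slash.
    structures_dir = os.path.join(folder, "structures")
    filepath = os.path.join(structures_dir, "cos-lstm-nn" + VERSION + "-final-{epoch:02d}-{val_acc:.2f}.hdf5")
    # Save a checkpoint only when validation accuracy improves.
    # Newer Keras versions may log this metric as 'val_accuracy'; rename it in both places if so.
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=0, save_best_only=True, mode='max')

    # Store the architecture (no weights) as JSON.
    with open(os.path.join(structures_dir, "cos-lstm-model" + VERSION + ".json"), "w") as json_file:
        json_file.write(self.model.to_json())
    self.model.summary()

    # Keep only the first MAX_SEQUENCE_LENGTH_C tokens of each context.
    context_data = np.array([x[:MAX_SEQUENCE_LENGTH_C] for x in context_data])
    context_data_v = np.array([x[:MAX_SEQUENCE_LENGTH_C] for x in context_data_v])

    # Dict keys must match the names of the model's input layers.
    self.model.fit({'context': context_data, 'question': question_data, 'answer': answer_data}, y_train,
                   validation_data=({'context': context_data_v, 'question': question_data_v,
                                     'answer': answer_data_v}, y_val),
                   epochs=50, batch_size=256, callbacks=[checkpoint], verbose=2)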
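`ModelCheckpoint` with `save_best_only=True` and `mode='max'` writes a file only when the monitored value beats the best value seen so far, and it fills the `{epoch:02d}` and `{val_acc:.2f}` placeholders from the epoch number and the logged metrics. The stdlib-only sketch below imitates that rule to show which file names a run would leave behind. The helper `checkpoint_files`, the `"v1"` version string and the sample accuracies are hypothetical; Keras itself saves the model, not just a name.

```python
import math

# Hypothetical stand-in for the template built in train(); "v1" replaces VERSION.
TEMPLATE = "cos-lstm-nnv1-final-{epoch:02d}-{val_acc:.2f}.hdf5"


def checkpoint_files(val_accs, template=TEMPLATE):
    """Return the file names a save_best_only=True, mode='max' checkpoint would write.

    Hypothetical helper: it mimics the improvement rule only and never calls Keras.
    """
    best = -math.inf  # mode='max' starts from negative infinity
    written = []
    for epoch_index, val_acc in enumerate(val_accs):
        if val_acc > best:  # only a strict improvement triggers a save
            best = val_acc
            # Keras fills {epoch} with a 1-based epoch number.
            written.append(template.format(epoch=epoch_index + 1, val_acc=val_acc))
    return written


print(checkpoint_files([0.61, 0.64, 0.63, 0.70]))
# ['cos-lstm-nnv1-final-01-0.61.hdf5', 'cos-lstm-nnv1-final-02-0.64.hdf5', 'cos-lstm-nnv1-final-04-0.70.hdf5']
```

Because the epoch and accuracy are part of the name, each improvement leaves a new `.hdf5` file next to the earlier ones instead of overwriting it; a fixed file name would keep only the best model.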
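The truncation step keeps the first `MAX_SEQUENCE_LENGTH_C` tokens of every context, but `np.array` only builds a rectangular 2-D array if every truncated sequence ends up the same length. With ragged input, older NumPy versions produce an object array and recent versions raise a `ValueError`. If the contexts are not already padded to at least that length, a helper that also right-pads them avoids the problem. Below is a minimal sketch of a hypothetical `truncate_and_pad` in plain Python, assuming token id `0` is reserved for padding; Keras' `pad_sequences` with `padding='post', truncating='post'` does the same job.

```python
def truncate_and_pad(sequences, max_len, pad_value=0):
    """Cut each token-id sequence to max_len and right-pad shorter ones.

    Keeps the first max_len tokens, matching x[:MAX_SEQUENCE_LENGTH_C].
    pad_value=0 assumes index 0 is reserved for padding (an assumption).
    """
    batch = []
    for seq in sequences:
        trimmed = list(seq[:max_len])  # works for lists, tuples and 1-D arrays
        trimmed.extend([pad_value] * (max_len - len(trimmed)))
        batch.append(trimmed)
    return batch


print(truncate_and_pad([[5, 8, 2, 9], [7]], 3))  # [[5, 8, 2], [7, 0, 0]]
```

Every row of the result has exactly `max_len` entries, so `np.array(truncate_and_pad(context_data, MAX_SEQUENCE_LENGTH_C))` yields a proper 2-D integer array.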