model.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:zhihu_kanshan_cup_2017 作者: coderSkyChen 项目源码 文件源码
def trainmodel(self, isalldata):

        self.buildmodel_rcnn4_att_titledsp()

        import time
        cur_time = time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time()))

        checkpointer = ModelCheckpoint(filepath=self.savedir + "/" + cur_time + "_model-{epoch:02d}.hdf5", period=1)
        zhihuMetrics = ZHIHUMetrics()

        if isalldata:
            self.model.fit([self.titlechar_array, self.titleword_array, self.dspchar_array, self.dspword_array],
                           self.y,
                           epochs=self.num_epochs, batch_size=self.batch_size, verbose=1,
                           callbacks=[checkpointer])
        else:#with 9:1 validation
            self.model.fit([self.titlechar_array, self.titleword_array, self.dspchar_array, self.dspword_array],
                           self.y,
                           validation_split=0.1,
                           epochs=self.num_epochs, batch_size=self.batch_size, verbose=1,
                           callbacks=[checkpointer, zhihuMetrics])
        self.save_model()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号