tfRNN.py source code

python

Project: SNLI-Keras  Author: adamzjk
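# Assumed import: the excerpt omits it, but these callbacks live in keras.callbacks.
from keras.callbacks import (CSVLogger, EarlyStopping, ModelCheckpoint,
                             ReduceLROnPlateau)


def start_train(self):
    """Train the model with the configured parameters, then evaluate the best checkpoint."""
    # 1. Prepare the callbacks
    checkpoint_path = self.rnn_type + '_' + self.dataset + '.check'
    callbacks = [EarlyStopping(patience=self.Patience),      # stop once validation loss stalls
                 ReduceLROnPlateau(patience=5, verbose=1),   # lower the learning rate on a plateau
                 CSVLogger(filename=self.rnn_type + 'log.csv'),
                 ModelCheckpoint(checkpoint_path,            # keep only the best weights
                                 save_best_only=True,
                                 save_weights_only=True)]

    # 2. Train (note: the test split doubles as the validation set here)
    self.model.fit(x=[self.train[0], self.train[1]],
                   y=self.train[2],
                   batch_size=self.BatchSize,
                   epochs=self.MaxEpoch,
                   validation_data=([self.test[0], self.test[1]], self.test[2]),
                   callbacks=callbacks)

    # 3. Evaluate
    self.model.load_weights(checkpoint_path)  # revert to the best model
    self.evaluate_on_test()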
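
The method above combines patience-based early stopping with a "keep only the best weights" checkpoint, then reloads that checkpoint before evaluating. The interplay can be sketched in plain Python without Keras. This is a simplified illustration, not Keras's actual implementation: the function name `run_with_early_stopping` and its return value are hypothetical, and it assumes the monitored quantity is a validation loss where lower is better.

```python
def run_with_early_stopping(val_losses, patience):
    """Simulate early stopping over a list of per-epoch validation losses.

    Returns (best_epoch, stopped_epoch), both 0-indexed. best_epoch is the
    epoch whose weights a save-best-only checkpoint would hold;
    stopped_epoch is the last epoch that actually ran.
    """
    best_loss = float("inf")
    best_epoch = -1
    wait = 0  # epochs since the last improvement
    stopped_epoch = len(val_losses) - 1

    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # Improvement: this is where a checkpoint would be overwritten.
            best_loss = loss
            best_epoch = epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                # No improvement for `patience` epochs: stop training.
                stopped_epoch = epoch
                break

    return best_epoch, stopped_epoch


if __name__ == "__main__":
    losses = [0.9, 0.7, 0.8, 0.75, 0.72, 0.6]
    print(run_with_early_stopping(losses, patience=3))
    print(run_with_early_stopping(losses, patience=4))
```

With a short patience, training can stop before a later improvement is reached, which is why the original method reloads the checkpointed weights rather than using whatever the model holds when `fit` returns.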