def start_train(self):
    """Train the model with the configured hyper-parameters, then evaluate it.

    Steps:
      1. Build the Keras callbacks: early stopping (patience ``self.Patience``),
         learning-rate reduction on plateau (patience 5, verbose), a CSV
         training log, and a checkpoint that keeps only the best weights.
      2. Fit ``self.model`` on ``self.train``: elements 0 and 1 form the
         input list and element 2 the targets. ``self.test`` is laid out the
         same way and is passed as validation data.
      3. Reload the best checkpointed weights and call
         ``self.evaluate_on_test()``.

    The quantity the callbacks monitor is not set here, so it is the Keras
    default (presumably ``val_loss`` -- confirm for the Keras version in use).
    """
    # One path shared by the checkpoint writer and the post-training reload.
    best_weights_path = self.rnn_type + '_' + self.dataset + '.check'

    # NOTE(review): unlike the checkpoint, the CSV log name has no '_' and no
    # dataset component, so runs with the same rnn_type on different datasets
    # write to the same log file. Consider aligning the two naming schemes.
    early_stop = EarlyStopping(patience=self.Patience)
    lr_on_plateau = ReduceLROnPlateau(patience=5, verbose=1)
    csv_log = CSVLogger(filename=self.rnn_type + 'log.csv')
    best_checkpoint = ModelCheckpoint(best_weights_path,
                                      save_best_only=True,
                                      save_weights_only=True)
    callbacks = [early_stop, lr_on_plateau, csv_log, best_checkpoint]

    train_inputs = [self.train[0], self.train[1]]
    train_targets = self.train[2]
    validation = ([self.test[0], self.test[1]], self.test[2])

    self.model.fit(x=train_inputs,
                   y=train_targets,
                   batch_size=self.BatchSize,
                   epochs=self.MaxEpoch,
                   validation_data=validation,
                   callbacks=callbacks)

    # Training may have continued past the best epoch; restore the best weights.
    self.model.load_weights(best_weights_path)
    self.evaluate_on_test()
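# Page-navigation residue from the web page this code was copied from (not code):
# "评论列表" (Comment list), "文章目录" (Table of contents)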