def fit(self, dataset, settings):
    """Compile ``self.model`` and train it on ``dataset``; return the training history dict.

    Args:
        dataset: 6-tuple ``(X_trn, y_trn, X_val, y_val, X_tst, y_tst)``. The label
            arrays hold class indices and are one-hot encoded here via
            ``np_utils.to_categorical``. ``y_tst`` may be empty, in which case no
            per-epoch test evaluation is performed.
        settings: configuration object. Attributes read here: ``dataset``,
            ``lrnalg``, ``lrnparam``, ``augment``, ``epoch``, ``batchsize``,
            ``verbose``. It is no longer modified by this method.

    Returns:
        ``self.model.history.history``: the metrics dict recorded during training.
        It is extended with a per-epoch ``'time'`` entry (wall-clock seconds). When
        a test set is supplied, it also gets a per-epoch ``'tst_acc'`` entry.
    """
    X_trn, y_trn, X_val, y_val, X_tst, y_tst = dataset
    # Use len(), not truthiness: y_tst may be a NumPy array.
    has_test = len(y_tst) > 0

    # 'cifar100' is the only dataset name treated as 100-class; all others get 10.
    n_classes = 100 if settings.dataset == 'cifar100' else 10
    y_trn = np_utils.to_categorical(y_trn, n_classes)
    y_val = np_utils.to_categorical(y_val, n_classes)
    y_tst = np_utils.to_categorical(y_tst, n_classes) if has_test else []

    # The optimizer receives every learning parameter except the one at index 1.
    # Build a local copy instead of overwriting settings.lrnparam. The old in-place
    # rewrite dropped one more parameter on every subsequent call to fit().
    lrnparam = settings.lrnparam[:1] + settings.lrnparam[2:]
    # NOTE(review): eval() executes arbitrary code taken from settings.lrnalg. This is
    # only acceptable while settings come from trusted input. Consider an explicit
    # name -> optimizer-class mapping instead.
    optimizer = eval(settings.lrnalg)(*lrnparam)
    self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=["accuracy"])

    class PerEpochTest(Callback):
        """Append each epoch's duration and test-set accuracy to the model's history."""

        def on_epoch_begin(self, epoch, logs=None):
            self.tic = time.time()

        def on_epoch_end(self, epoch, logs=None):
            # Measure elapsed time before evaluating, so test evaluation is not counted.
            elapsed = time.time() - self.tic
            self.model.history.history.setdefault('time', []).append(elapsed)
            # With metrics=["accuracy"], index 1 of evaluate()'s result is accuracy.
            tst_acc = self.model.evaluate(X_tst, y_tst, batch_size=settings.batchsize, verbose=0)[1]
            self.model.history.history.setdefault('tst_acc', []).append(tst_acc)

    aug = augment(settings.dataset) if settings.augment else None
    # NOTE: nb_epoch / samples_per_epoch / nb_worker / pickle_safe are Keras 1.x
    # keyword names. Keras 2 renamed them.
    fit_kwargs = {
        'nb_epoch': settings.epoch,
        'validation_data': (X_val, y_val),
        # Per-epoch test evaluation is only attached when a test set exists.
        'callbacks': [PerEpochTest()] if has_test else [],
        'verbose': settings.verbose,
    }
    if aug is None:
        self.model.fit(X_trn, y_trn, batch_size=settings.batchsize, **fit_kwargs)
    else:
        # One epoch = len(X_trn) samples drawn from the augmenting generator.
        self.model.fit_generator(
            aug.flow(X_trn, y_trn, batch_size=settings.batchsize),
            samples_per_epoch=len(X_trn),
            nb_worker=4,
            pickle_safe=True,
            **fit_kwargs
        )
    return self.model.history.history
评论列表
文章目录