network.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:tsnet 作者: coxlab 项目源码 文件源码
def fit(self, dataset, settings):
        """Compile ``self.model`` and train it on ``dataset``.

        Args:
            dataset: 6-tuple ``(X_trn, y_trn, X_val, y_val, X_tst, y_tst)``.
                Labels are integer class vectors; they are one-hot encoded
                here with ``np_utils.to_categorical``. ``y_tst`` may be empty,
                in which case no per-epoch test evaluation is performed.
            settings: configuration object. Attributes read here:
                ``dataset``, ``lrnalg``, ``lrnparam``, ``batchsize``,
                ``augment``, ``epoch``, ``verbose``.

        Returns:
            The ``history`` dict of ``self.model.history`` after training.
            When a non-empty test set is given, it also holds the per-epoch
            lists ``'time'`` and ``'tst_acc'``. ``'time'`` is the seconds
            between epoch begin and epoch end, measured before the test
            evaluation. ``'tst_acc'`` is the second value returned by
            ``model.evaluate``, i.e. the "accuracy" metric compiled below.

        Side effects:
            Rebinds ``settings.lrnparam`` on the caller's object with the
            element at index 1 removed.
        """
        X_trn, y_trn, X_val, y_val, X_tst, y_tst = dataset

        # CIFAR-100 has 100 classes; every other supported dataset is treated as 10-class.
        nb_classes = 100 if settings.dataset == 'cifar100' else 10

        y_trn = np_utils.to_categorical(y_trn, nb_classes)
        y_val = np_utils.to_categorical(y_val, nb_classes)
        # An empty test set stays an empty list so the len() checks below can skip testing.
        y_tst = np_utils.to_categorical(y_tst, nb_classes) if len(y_tst) > 0 else []

        # NOTE(review): drops lrnparam[1] and rebinds the attribute on the caller's
        # settings object, so calling fit() twice drops a second entry. Kept for
        # compatibility -- confirm whether this mutation is intended.
        settings.lrnparam = (settings.lrnparam[:1] + settings.lrnparam[2:])

        # SECURITY(review): eval() executes arbitrary code. This is only acceptable if
        # settings.lrnalg comes from a trusted source. It presumably names an optimizer
        # class visible in this module's globals (e.g. 'SGD') -- confirm, and consider
        # an explicit name -> class lookup table instead.
        optimizer = eval(settings.lrnalg)(*settings.lrnparam)
        self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=["accuracy"])

        class PerEpochTest(Callback):
            """Append epoch wall-clock time and test-set accuracy to the model's history."""

            def on_epoch_begin(self, epoch, logs=None):
                self.tic = time.time()

            def on_epoch_end(self, epoch, logs=None):
                history = self.model.history.history
                # Record elapsed time before evaluation, so 'time' excludes the test pass.
                history.setdefault('time', []).append(time.time() - self.tic)

                # evaluate() returns [loss, accuracy] because the model was compiled
                # with metrics=["accuracy"]; keep only the accuracy.
                tst_acc = self.model.evaluate(X_tst, y_tst, batch_size=settings.batchsize, verbose=0)[1]
                history.setdefault('tst_acc', []).append(tst_acc)

        aug = augment(settings.dataset) if settings.augment else None

        # len() rather than truthiness: y_tst is a NumPy array when non-empty, and
        # bool(array) is ambiguous for arrays with more than one element.
        callbacks = [PerEpochTest()] if len(y_tst) > 0 else []
        # Keyword names (nb_epoch, samples_per_epoch, nb_worker, pickle_safe) follow the Keras 1.x API.
        fit_kwargs = {
            'nb_epoch': settings.epoch,
            'validation_data': (X_val, y_val),
            'callbacks': callbacks,
            'verbose': settings.verbose,
        }

        if aug is None:
            self.model.fit(X_trn, y_trn, batch_size=settings.batchsize, **fit_kwargs)
        else:
            self.model.fit_generator(aug.flow(X_trn, y_trn, batch_size=settings.batchsize),
                                     samples_per_epoch=len(X_trn), nb_worker=4, pickle_safe=True,
                                     **fit_kwargs)

        return self.model.history.history
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号