train.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目: Predicting-First-Impressions 作者: mel-2445 项目源码 文件源码
def train(Xtrain, ytrain, Xtrain_norm, ytrain_norm, Xvalidate, yvalidate, space):
    """Build, compile and train a ``vgg_variant`` model with correlation-based early stopping.

    :param Xtrain: training inputs fed to ``data_generator``
    :param ytrain: training targets fed to ``data_generator``
    :param Xtrain_norm: training inputs used only to score the model after each epoch
    :param ytrain_norm: training targets used only to score the model after each epoch
    :param Xvalidate: validation inputs, scored after each epoch and also passed
        to Keras as validation data
    :param yvalidate: validation targets
    :param space: hyper-parameter dict. It is passed to ``vgg_variant`` and
        ``data_generator``. It must provide 'learning_rate' (used as
        ``10 ** -learning_rate``), 'batch_size', 'weighted_sampling', 'augment',
        'sampling_factor', 'sampling_intercept' and 'samples_per_epoch'.
    :return: ``(model, rvalues)``. ``model`` carries the weights from the best
        monitored epoch, or is None if no epoch's score reached ``delta``.
        ``rvalues`` is ``{'train': [...], 'validate': [...]}`` holding one
        ``get_metrics`` result per completed epoch.
    """
    import sys
    from keras.optimizers import RMSprop
    from keras.callbacks import Callback

    class CorrelationEarlyStopping(Callback):
        """Stop training once the monitored ``get_metrics`` score stops improving.

        The weights of the best epoch are snapshotted and restored when training
        ends. Keeping only a reference to ``self.model`` would return whatever
        weights the final epoch left behind.
        """

        def __init__(self, monitor='validate', patience=0, delta=.001):
            """
            :param monitor: 'validate' or 'train'
            :param patience: number of consecutive non-improving epochs tolerated;
                training stops at the end of the next non-improving epoch
            :param delta: by how much the monitored value has to exceed the
                current best to count as an improvement
            """
            super(CorrelationEarlyStopping, self).__init__()
            self.rvalues = {'train': [], 'validate': []}
            self.monitor = monitor  # validate, train
            self.patience = patience
            self.delta = delta
            self.wait = 0
            self.best = 0
            self.num_epochs = 0
            self.best_model = None
            # Snapshot of the weights at the best epoch; restored in on_train_end.
            self.best_weights = None

        def on_epoch_end(self, epoch, logs=None):
            r2 = get_metrics(self.model, x=Xtrain_norm, y=ytrain_norm)
            self.rvalues['train'].append(r2)
            r2 = get_metrics(self.model, x=Xvalidate, y=yvalidate)
            self.rvalues['validate'].append(r2)
            print('\n\tTrain r2: {}\n\tValidate r2: {}\n'.format(self.rvalues['train'][-1], self.rvalues['validate'][-1]))
            sys.stdout.flush()

            current = self.rvalues[self.monitor][-1]
            if current - self.delta >= self.best:
                self.best = current
                self.wait = 0
                self.num_epochs = epoch
                self.best_model = self.model
                # get_weights() returns NumPy arrays, so later training updates
                # do not overwrite this snapshot.
                self.best_weights = self.model.get_weights()
            elif self.wait >= self.patience:
                self.num_epochs = epoch - self.patience
                self.model.stop_training = True
            else:
                self.num_epochs = epoch
                self.wait += 1

        def on_train_end(self, logs=None):
            # Roll the (shared) model object back to the best epoch's weights so
            # that best_model reflects the best epoch, not the last one trained.
            if self.best_weights is not None:
                self.model.set_weights(self.best_weights)

    model = vgg_variant(space)
    # space['learning_rate'] is a negative base-10 exponent (e.g. 3 -> 1e-3).
    lr = 10**(-space['learning_rate'])
    rmsprop = RMSprop(lr=lr, rho=0.9, epsilon=1e-08)
    model.compile(loss='mean_squared_error', optimizer=rmsprop)
    monitor = CorrelationEarlyStopping(monitor='validate', patience=6, delta=0.01)
    gen = data_generator(Xtrain, ytrain, batch_size=space['batch_size'], space=space,
                         weighted_sampling=space['weighted_sampling'], augment=space['augment'],
                         sampling_factor=space['sampling_factor'], sampling_intercept=space['sampling_intercept'])
    # Positional arguments, assuming the Keras 1.x fit_generator order:
    # samples_per_epoch, nb_epoch=50, verbose=1, callbacks, validation_data.
    model.fit_generator(gen, space['samples_per_epoch'], 50, 1, [monitor], (Xvalidate, yvalidate))
    return monitor.best_model, monitor.rvalues
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号