Example source code for Python's CSVLogger() class
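
CSVLogger is a Keras callback that streams per-epoch metrics (loss, accuracy, and any val_* metrics) to a CSV file, so a training run can be inspected and plotted after the fact. As a baseline before the project examples below, here is a minimal self-contained sketch against the Keras 2 API; the toy model, data, and file name are placeholders, not code from any of the projects:

# Minimal CSVLogger usage; the model, data, and 'training.csv' are illustrative.
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import CSVLogger

x = np.random.rand(100, 8)
y = np.random.randint(0, 2, size=(100, 1))

model = Sequential([Dense(4, input_dim=8, activation='relu'),
                    Dense(1, activation='sigmoid')])
model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])

# Each completed epoch appends one row: epoch, acc, loss (plus val_* columns
# when validation data is supplied).
model.fit(x, y, epochs=3, callbacks=[CSVLogger('training.csv')])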

Source: trainer.py (project: multi-gpu-keras-tf, author: sallamander)
def _get_callbacks(self):
        """Return callbacks to pass into the Model.fit method

        Note: This simply returns statically instantiated callbacks. In the
        future it could be altered to allow for callbacks that are specified
        and configured via a training config.
        """

        fpath_history = os.path.join(self.output_dir, 'history.csv')
        fpath_weights = os.path.join(self.output_dir, 'weights.h5')

        csv_logger = CSVLogger(filename=fpath_history)
        model_checkpoint = ModelCheckpoint(
            filepath=fpath_weights, verbose=True
        )
        callbacks = [csv_logger, model_checkpoint]

        return callbacks
Source: load_deepmodels.py (project: Youtube8mdataset_kagglechallenge, author: jasonlee27)
def train(self, model, saveto_path=''):
        x_train, y_train = get_data(self.train_data_path, "train", "frame", self.feature_type)
        print('%d training frame level samples.' % len(x_train))
        x_valid, y_valid = get_data(self.valid_data_path, "valid", "frame", self.feature_type)
        print('%d validation frame level samples.' % len(x_valid))

        sgd = SGD(lr=0.01,
                  decay=1e-6,
                  momentum=0.9,
                  nesterov=True)
        model.compile(loss='categorical_crossentropy',
                      optimizer=sgd,
                      metrics=['accuracy'])

        callbacks = list()
        callbacks.append(CSVLogger(LOG_FILE))
        callbacks.append(ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=0.0001))

        if saveto_path:
            callbacks.append(ModelCheckpoint(filepath=MODEL_WEIGHTS, verbose=1))

        model.fit(x_train,
                  y_train,
                  epochs=5,
                  callbacks=callbacks,
                  validation_data=(x_valid, y_valid))

        # Save the weights on completion.
        if saveto_path:
            model.save_weights(saveto_path)
Source: inception_flowers_tune.py (project: keras-surgeon, author: BenWhetton)
def train_top_model():
    # Load the bottleneck features and labels
    train_features = np.load(open(output_dir+'bottleneck_features_train.npy', 'rb'))
    train_labels = np.load(open(output_dir+'bottleneck_labels_train.npy', 'rb'))
    validation_features = np.load(open(output_dir+'bottleneck_features_validation.npy', 'rb'))
    validation_labels = np.load(open(output_dir+'bottleneck_labels_validation.npy', 'rb'))

    # Create the top model for the inception V3 network, a single Dense layer
    # with softmax activation.
    top_input = Input(shape=train_features.shape[1:])
    top_output = Dense(5, activation='softmax')(top_input)
    model = Model(top_input, top_output)

    # Train the model using the bottleneck features and save the weights.
    model.compile(optimizer=SGD(lr=1e-4, momentum=0.9),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    csv_logger = CSVLogger(output_dir + 'top_model_training.csv')
    model.fit(train_features, train_labels,
              epochs=top_epochs,
              batch_size=batch_size,
              validation_data=(validation_features, validation_labels),
              callbacks=[csv_logger])
    model.save_weights(top_model_weights_path)
Source: keras_spell.py (project: DeepSpell_temp, author: surmenok)
def iterate_training(model, dataset, initial_epoch):
    """Iterative Training"""

    checkpoint = ModelCheckpoint(MODEL_CHECKPOINT_DIRECTORYNAME + '/' + MODEL_CHECKPOINT_FILENAME,
                                 save_best_only=True)
    tensorboard = TensorBoard()
    csv_logger = CSVLogger(CSV_LOG_FILENAME)

    X_dev_batch, y_dev_batch = next(dataset.dev_set_batch_generator(1000))
    show_samples_callback = LambdaCallback(
        on_epoch_end=lambda epoch, logs: show_samples(model, dataset, epoch, logs, X_dev_batch, y_dev_batch))

    train_batch_generator = dataset.train_set_batch_generator(BATCH_SIZE)
    validation_batch_generator = dataset.dev_set_batch_generator(BATCH_SIZE)

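    # Note: samples_per_epoch / nb_epoch / nb_val_samples are the Keras 1
    # argument names (steps_per_epoch / epochs / validation_steps in Keras 2).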
    model.fit_generator(train_batch_generator,
                        samples_per_epoch=SAMPLES_PER_EPOCH,
                        nb_epoch=NUMBER_OF_EPOCHS,
                        validation_data=validation_batch_generator,
                        nb_val_samples=SAMPLES_PER_EPOCH,
                        callbacks=[checkpoint, tensorboard, csv_logger, show_samples_callback],
                        verbose=1,
                        initial_epoch=initial_epoch)
Source: Config.py (project: NetworkCompress, author: luzai)
def set_logger_path(self, name):
        self.csv_logger = CSVLogger(osp.join(self.output_path, name))
Source: DEC.py (project: DEC-keras, author: XifengGuo)
def pretrain(self, x, y=None, optimizer='adam', epochs=200, batch_size=256, save_dir='results/temp'):
        print('...Pretraining...')
        self.autoencoder.compile(optimizer=optimizer, loss='mse')

        csv_logger = callbacks.CSVLogger(save_dir + '/pretrain_log.csv')
        cb = [csv_logger]
        if y is not None:
            class PrintACC(callbacks.Callback):
                def __init__(self, x, y):
                    self.x = x
                    self.y = y
                    super(PrintACC, self).__init__()

                def on_epoch_end(self, epoch, logs=None):
                    if epoch % max(1, epochs // 10) != 0:  # guard against epochs < 10
                        return
                    feature_model = Model(self.model.input,
                                          self.model.get_layer(
                                              'encoder_%d' % (int(len(self.model.layers) / 2) - 1)).output)
                    features = feature_model.predict(self.x)
                    km = KMeans(n_clusters=len(np.unique(self.y)), n_init=20, n_jobs=4)
                    y_pred = km.fit_predict(features)
                    # print()
                    print(' '*8 + '|==>  acc: %.4f,  nmi: %.4f  <==|'
                          % (metrics.acc(self.y, y_pred), metrics.nmi(self.y, y_pred)))

            cb.append(PrintACC(x, y))

        # begin pretraining
        t0 = time()
        self.autoencoder.fit(x, x, batch_size=batch_size, epochs=epochs, callbacks=cb)
        print('Pretraining time: ', time() - t0)
        self.autoencoder.save_weights(save_dir + '/ae_weights.h5')
        print('Pretrained weights are saved to %s/ae_weights.h5' % save_dir)
        self.pretrained = True
Source: yt8m_frame_model2.py (project: Youtube8mdataset_kagglechallenge, author: jasonlee27)
def train(self, model, saveto_path=''):
        x_train, y_train = get_data(self.train_data_path, "train", "frame", self.feature_type)
        print('%d training frame level samples.' % len(x_train))
        x_valid, y_valid = get_data(self.valid_data_path, "valid", "frame", self.feature_type)
        print('%d validation frame level samples.' % len(x_valid))

        sgd = SGD(lr=0.001,
                  decay=1e-6,
                  momentum=0.9,
                  nesterov=True)
        model.compile(loss='binary_crossentropy',
                      optimizer=sgd,
                      metrics=['accuracy'])

        callbacks = list()
        callbacks.append(CSVLogger(LOG_FILE))
        callbacks.append(ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=0.0001))

        if saveto_path:
            callbacks.append(ModelCheckpoint(filepath=saveto_path, verbose=1))

        model.fit(x_train,
                  y_train,
                  nb_epoch=5,
                  callbacks=callbacks,
                  validation_data=(x_valid, y_valid))

        # Save the weights on completion.
        if saveto_path:
            model.save_weights(saveto_path)
Source: yt8m_video_model.py (project: Youtube8mdataset_kagglechallenge, author: jasonlee27)
def train(self, model, saveto_path=''):
        x_train, y_train = get_data(self.train_data_path, "train", "video", self.feature_type)
        print('%d training video level samples.' % len(x_train))
        x_valid, y_valid = get_data(self.valid_data_path, "valid", "video", self.feature_type)
        print('%d validation video level samples.' % len(x_valid))

        sgd = SGD(lr=0.001,
                  decay=1e-6,
                  momentum=0.9,
                  nesterov=True)
        model.compile(loss='categorical_crossentropy',
                      optimizer=sgd,
                      metrics=['accuracy'])

        callbacks = list()
        callbacks.append(CSVLogger(LOG_FILE))
        callbacks.append(ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=0.0001))

        if saveto_path:
            callbacks.append(ModelCheckpoint(filepath=VID_MODEL_WEIGHTS, verbose=1))

        model.fit(x_train,
                  y_train,
                  epochs=5,
                  callbacks=callbacks,
                  validation_data=(x_valid, y_valid))

        # Save the weights on completion.
        if saveto_path:
            model.save_weights(saveto_path)
Source: yt8m_frame_model.py (project: Youtube8mdataset_kagglechallenge, author: jasonlee27)
def train(self, model, saveto_path=''):
        x_train, y_train = get_data(self.train_data_path, "train", "frame", self.feature_type)
        print('%d training frame level samples.' % len(x_train))
        x_valid, y_valid = get_data(self.valid_data_path, "valid", "frame", self.feature_type)
        print('%d validation frame level samples.' % len(x_valid))

        sgd = SGD(lr=0.001,
                  decay=1e-6,
                  momentum=0.9,
                  nesterov=True)
        model.compile(loss='binary_crossentropy',
                      optimizer=sgd,
                      metrics=['accuracy'])

        callbacks = list()
        callbacks.append(CSVLogger(LOG_FILE))
        callbacks.append(ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=0.0001))

        if saveto_path:
            callbacks.append(ModelCheckpoint(filepath=saveto_path, verbose=1))

        model.fit(x_train,
                  y_train,
                  nb_epoch=5,
                  callbacks=callbacks,
                  validation_data=(x_valid, y_valid))

        # Save the weights on completion.
        if saveto_path:
            model.save_weights(saveto_path)
Source: load_deepmodels.py (project: Youtube8mdataset_kagglechallenge, author: jasonlee27)
def train(self, model, saveto_path=''):
        x_train, y_train = get_data(self.train_data_path, "train", "video", self.feature_type)
        print('%d training video level samples.' % len(x_train))
        x_valid, y_valid = get_data(self.valid_data_path, "valid", "video", self.feature_type)
        print('%d validation video level samples.' % len(x_valid))

        sgd = SGD(lr=0.01,
                  decay=1e-6,
                  momentum=0.9,
                  nesterov=True)
        model.compile(loss='categorical_crossentropy',
                      optimizer=sgd,
                      metrics=['accuracy'])

        callbacks = list()
        callbacks.append(CSVLogger(LOG_FILE))
        callbacks.append(ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=0.0001))

        if saveto_path:
            callbacks.append(ModelCheckpoint(filepath=MODEL_WEIGHTS, verbose=1))

        model.fit(x_train,
                  y_train,
                  epochs=5,
                  callbacks=callbacks,
                  validation_data=(x_valid, y_valid))

        # Save the weights on completion.
        if saveto_path:
            model.save_weights(saveto_path)
Source: create.py (project: segmenty, author: paulfitz)
def __enter__(self):
        chk = ModelCheckpoint(self.checkpoint, verbose=0, save_best_only=False,
                              save_weights_only=False, mode='auto')
        csv_logger = CSVLogger('training.log')
        snaps = LambdaCallback(on_epoch_end=lambda epoch, logs: self.snap(epoch))
        return [chk, csv_logger, snaps]
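
The __enter__ above returns the callback list, so the enclosing object evidently wraps a training run as a context manager. A minimal self-contained sketch of that pattern follows; the CallbackSession class, its arguments, and the commented usage are assumptions for illustration, not code from the project:

# Hypothetical context manager that hands a callback list to model.fit.
from keras.callbacks import CSVLogger, ModelCheckpoint

class CallbackSession:
    def __init__(self, checkpoint_path, log_path):
        self.checkpoint_path = checkpoint_path
        self.log_path = log_path

    def __enter__(self):
        # Build the callbacks on entry, mirroring the snippet above.
        return [ModelCheckpoint(self.checkpoint_path, save_best_only=False),
                CSVLogger(self.log_path)]

    def __exit__(self, exc_type, exc_value, traceback):
        return False  # propagate any exception raised in the with-block

# Usage (model, x_train, y_train assumed to exist):
# with CallbackSession('weights.h5', 'training.log') as callbacks:
#     model.fit(x_train, y_train, epochs=10, callbacks=callbacks)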
Source: gan.py (project: Keras-GAN-Animeface-Character, author: forcecore)
def train_gan( dataf ) :
    gen, disc, gan = build_networks()

    # Uncomment these, if you want to continue training from some snapshot.
    # (or load pretrained generator weights)
    #load_weights(gen, Args.genw)
    #load_weights(disc, Args.discw)

    logger = CSVLogger('loss.csv') # yeah, you can use callbacks independently
    logger.on_train_begin() # initialize csv file
    with h5py.File( dataf, 'r' ) as f :
        faces = f.get( 'faces' )
        run_batches(gen, disc, gan, faces, logger, range(5000))
    logger.on_train_end()
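
As the comment notes, a Keras callback can be driven by hand outside model.fit: call on_train_begin() once to open the file, on_epoch_end(epoch, logs) with a dict of values for each row, and on_train_end() to close it, which is presumably what run_batches does with the logger. A minimal self-contained sketch of that pattern (the model, loss names, and values are placeholders, not taken from the project):

# Driving CSVLogger manually; model and losses are illustrative only.
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import CSVLogger

model = Sequential([Dense(1, input_dim=2)])
model.compile(loss='mse', optimizer='sgd')
model.stop_training = False    # some Keras versions read this in on_epoch_end

logger = CSVLogger('loss.csv')
logger.set_model(model)        # attach a model to satisfy the Callback API
logger.on_train_begin()        # opens loss.csv and prepares the header
for step in range(3):
    d_loss, g_loss = 0.7, 1.3  # placeholders; a real loop would train here
    logger.on_epoch_end(step, logs={'d_loss': d_loss, 'g_loss': g_loss})
logger.on_train_end()          # flushes and closes the file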
Source: train_cnn.py (project: keras-anime-face-recognition, author: namakemono)
def run():
    (X_train, y_train), (X_test, y_test) = datasets.load_data(img_rows=32, img_cols=32)
    Y_train = np_utils.to_categorical(y_train, nb_classes)
    Y_test = np_utils.to_categorical(y_test, nb_classes)
    model = CNN(input_shape=X_train.shape[1:], nb_classes=nb_classes)
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    X_train = preprocess_input(X_train)
    X_test = preprocess_input(X_test)
    csv_logger = CSVLogger('../log/cnn.log')
    checkpointer = ModelCheckpoint(filepath="/tmp/weights.hdf5", monitor="val_acc", verbose=1, save_best_only=True)
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,  # randomly flip images
        vertical_flip=False)  # randomly flip images
    datagen.fit(X_train)
    model.fit_generator(datagen.flow(X_train, Y_train,
                                     batch_size=batch_size),
                        samples_per_epoch=X_train.shape[0],
                        nb_epoch=nb_epoch,
                        validation_data=(X_test, Y_test), 
                        callbacks=[csv_logger, checkpointer])
Source: neuralnet_node_residual.py (project: skp_edu_docker, author: TensorMSA)
def get_model_resnet(self):
        try:
            keras.backend.tensorflow_backend.clear_session()
            self.lr_reducer = ReduceLROnPlateau(monitor='val_loss', factor=np.sqrt(0.1), cooldown=0, patience=5, min_lr=0.5e-6)
            self.early_stopper = EarlyStopping(monitor='val_acc', min_delta=0.001, patience=10)
            self.csv_logger = CSVLogger('resnet.csv')
            num_classes = self.netconf["config"]["num_classes"]
            numoutputs = self.netconf["config"]["layeroutputs"]
            x_size = self.dataconf["preprocess"]["x_size"]
            y_size = self.dataconf["preprocess"]["y_size"]
            channel = self.dataconf["preprocess"]["channel"]
            optimizer = self.netconf["config"]["optimizer"]

            filelist = os.listdir(self.model_path)
            filelist.sort(reverse=True)
            last_chk_path = self.model_path + "/" + self.load_batch + self.file_end

            try:
                self.model = keras.models.load_model(last_chk_path)
                logging.info("Train Restored checkpoint from:" + last_chk_path)
            except Exception as e:
                if numoutputs == 18:
                    self.model = resnet.ResnetBuilder.build_resnet_18((channel, x_size, y_size), num_classes)
                elif numoutputs == 34:
                    self.model = resnet.ResnetBuilder.build_resnet_34((channel, x_size, y_size), num_classes)
                elif numoutputs == 50:
                    self.model = resnet.ResnetBuilder.build_resnet_50((channel, x_size, y_size), num_classes)
                elif numoutputs == 101:
                    self.model = resnet.ResnetBuilder.build_resnet_101((channel, x_size, y_size), num_classes)
                elif numoutputs == 152:
                    self.model = resnet.ResnetBuilder.build_resnet_152((channel, x_size, y_size), num_classes)
                elif numoutputs == 200:
                    self.model = resnet.ResnetBuilder.build_resnet_200((channel, x_size, y_size), num_classes)
                logging.info("None to restore checkpoint. Initializing variables instead." + last_chk_path)
                logging.info(e)

            self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        except Exception as e:
            logging.error("===Error on Residualnet build model : {0}".format(e))

    ####################################################################################################################
Source: model.py (project: CIAN, author: yanghanxy)
def train_model(opt, logger):
    logger.info('---START---')
    # initialize for reproduce
    np.random.seed(opt.seed)

    # load data
    logger.info('---LOAD DATA---')
    opt, training, training_snli, validation, test_matched, test_mismatched = load_data(opt)

    if not opt.skip_train:
        logger.info('---TRAIN MODEL---')
        for train_counter in range(opt.max_epochs):
            if train_counter == 0:
                model = build_model(opt)
            else:
                model = load_model_local(opt)
            np.random.seed(train_counter)
            lens = len(training_snli[-1])
            perm = np.random.permutation(lens)
            idx = perm[:int(lens * 0.2)]
            train_data = [np.concatenate((training[0], training_snli[0][idx])),
                          np.concatenate((training[1], training_snli[1][idx])),
                          np.concatenate((training[2], training_snli[2][idx]))]
            csv_logger = CSVLogger('{}{}.csv'.format(opt.log_dir, opt.model_name), append=True)
            cp_filepath = opt.save_dir + "cp-" + opt.model_name + "-" + str(train_counter) + "-{val_acc:.2f}.h5"
            cp = ModelCheckpoint(cp_filepath, monitor='val_acc', save_best_only=True, save_weights_only=True)
            callbacks = [cp, csv_logger]
            model.fit(train_data[:-1], train_data[-1], batch_size=opt.batch_size, epochs=1, validation_data=(validation[:-1], validation[-1]), callbacks=callbacks)
            save_model_local(opt, model)
    else:
        logger.info('---LOAD MODEL---')
        model = load_model_local(opt)

    # predict
    logger.info('---TEST MODEL---')
    preds_matched = model.predict(test_matched[:-1], batch_size=128, verbose=1)
    preds_mismatched = model.predict(test_mismatched[:-1], batch_size=128, verbose=1)

    save_preds_matched_to_csv(preds_matched, test_matched[-1], opt)  # matched preds pair with matched test ids
    save_preds_mismatched_to_csv(preds_mismatched, test_mismatched[-1], opt)
Source: ockre.py (project: OCkRE, author: rossumai)
def train(self, run_name, start_epoch, stop_epoch, verbose=1, epochlen=2048, vallen=2000):

        # Dummy iterators; in real use they would be passed in from outside,
        # along with separate real training and validation data.
        train_crop_iter = CropImageIterator()
        val_crop_iter = CropImageIterator()

        words_per_epoch = epochlen
        val_words = len(val_crop_iter)
        img_gen = DataGenerator(minibatch_size=32, img_w=self.img_w, img_h=self.img_h, downsample_factor=(self.pool_size ** 2),
                                train_crop_iter=train_crop_iter,
                                val_crop_iter=val_crop_iter,
                                absolute_max_string_len=self.absolute_max_string_len,
                                train_realratio=1.0,
                                val_realratio=1.0
                                )
        if vallen:
            val_words = vallen

        adam = keras.optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        output_dir = os.path.join(OUTPUT_DIR, run_name)

        # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
        self.model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=adam)
        if start_epoch > 0:
            weight_file = os.path.join(OUTPUT_DIR, os.path.join(run_name, 'weights%02d.h5' % (start_epoch - 1)))
            self.model.load_weights(weight_file)

        viz_cb = VizCallback(run_name, self.test_func, img_gen.next_val(), self.model, val_words)

        weights_best_fname = os.path.join(output_dir, '%s-weights-best_loss.h5' % run_name)
        weights_best_fnamev = os.path.join(output_dir, '%s-weights-best_val_loss.h5' % run_name)
        weights_best_fnamemned = os.path.join(output_dir, '%s-weights-best_mned.h5' % run_name)
        weights_best_cro_accu = os.path.join(output_dir, '%s-weights-best_crop_accu.h5' % run_name)

        csv_logger = CSVLogger(os.path.join(output_dir, '%s.training.log' % run_name))

        checkpointer_loss = ModelCheckpoint(weights_best_fname, monitor='loss', save_best_only=True, save_weights_only=False, mode='min')
        checkpointer_vloss = ModelCheckpoint(weights_best_fnamev, monitor='val_loss', save_best_only=True, save_weights_only=False, mode='min')
        checkpointer_mned = ModelCheckpoint(weights_best_fnamemned, monitor='mean_norm_ed', save_best_only=True, save_weights_only=False, mode='min')
        checkpointer_accu = ModelCheckpoint(weights_best_cro_accu, monitor='crop_accuracy', save_best_only=True, save_weights_only=False, mode='max')

        self.model.fit_generator(generator=img_gen.next_train(), samples_per_epoch=words_per_epoch,
                                 nb_epoch=stop_epoch, validation_data=img_gen.next_val(), nb_val_samples=val_words,
                                 callbacks=[viz_cb, img_gen, checkpointer_loss, checkpointer_vloss, checkpointer_mned, checkpointer_accu, csv_logger],
                                 initial_epoch=start_epoch, verbose=verbose)
Source: test_callbacks.py (project: keras, author: NVIDIA)
def test_CSVLogger():
    filepath = 'log.tsv'
    sep = '\t'
    (X_train, y_train), (X_test, y_test) = get_test_data(nb_train=train_samples,
                                                         nb_test=test_samples,
                                                         input_shape=(input_dim,),
                                                         classification=True,
                                                         nb_class=nb_class)
    y_test = np_utils.to_categorical(y_test)
    y_train = np_utils.to_categorical(y_train)

    def make_model():
        np.random.seed(1337)
        model = Sequential()
        model.add(Dense(nb_hidden, input_dim=input_dim, activation='relu'))
        model.add(Dense(nb_class, activation='softmax'))

        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizers.SGD(lr=0.1),
                      metrics=['accuracy'])
        return model

    # case 1, create new file with defined separator
    model = make_model()
    cbks = [callbacks.CSVLogger(filepath, separator=sep)]
    model.fit(X_train, y_train, batch_size=batch_size,
              validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=1)

    assert os.path.exists(filepath)
    with open(filepath) as csvfile:
        dialect = Sniffer().sniff(csvfile.read())
    assert dialect.delimiter == sep
    del model
    del cbks

    # case 2, append data to existing file, skip header
    model = make_model()
    cbks = [callbacks.CSVLogger(filepath, separator=sep, append=True)]
    model.fit(X_train, y_train, batch_size=batch_size,
              validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=1)

    # case 3, reuse of CSVLogger object
    model.fit(X_train, y_train, batch_size=batch_size,
              validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=1)

    import re
    with open(filepath) as csvfile:
        output = " ".join(csvfile.readlines())
        assert len(re.findall('epoch', output)) == 1

    os.remove(filepath)
Source: experiment.py (project: srcnn, author: qobilidop)
def train(self, train_set='91-image', val_set='Set5', epochs=1,
              resume=True):
        # Load and process data
        x_train, y_train = self.load_set(train_set)
        x_val, y_val = self.load_set(val_set)
        x_train, x_val = [self.pre_process(x)
                          for x in [x_train, x_val]]
        y_train, y_val = [self.inverse_post_process(y)
                          for y in [y_train, y_val]]

        # Compile model
        model = self.compile(self.build_model(x_train))
        model.summary()

        # Save model architecture
        # Currently in Keras 2 it's not possible to load a model with custom
        # layers. So we just save it without checking consistency.
        self.config_file.write_text(model.to_yaml())

        # Inherit weights
        if resume:
            latest_epoch = self.latest_epoch
            if latest_epoch > -1:
                weights_file = self.weights_file(epoch=latest_epoch)
                model.load_weights(str(weights_file))
            initial_epoch = latest_epoch + 1
        else:
            initial_epoch = 0

        # Set up callbacks
        callbacks = []
        callbacks += [ModelCheckpoint(str(self.model_file))]
        callbacks += [ModelCheckpoint(str(self.weights_file()),
                                      save_weights_only=True)]
        callbacks += [CSVLogger(str(self.history_file), append=resume)]

        # Train
        model.fit(x_train, y_train, epochs=epochs, callbacks=callbacks,
                  validation_data=(x_val, y_val), initial_epoch=initial_epoch)

        # Plot metrics history
        prefix = str(self.history_file).rsplit('.', maxsplit=1)[0]
        df = pd.read_csv(str(self.history_file))
        epoch = df['epoch']
        for metric in ['Loss', 'PSNR']:
            train = df[metric.lower()]
            val = df['val_' + metric.lower()]
            plt.figure()
            plt.plot(epoch, train, label='train')
            plt.plot(epoch, val, label='val')
            plt.legend(loc='best')
            plt.xlabel('Epoch')
            plt.ylabel(metric)
            plt.savefig('.'.join([prefix, metric.lower(), 'png']))
            plt.close()

