load_deepmodels.py source code

python
Project: Youtube8mdataset_kagglechallenge  Author: jasonlee27
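# Imports this method relies on. The keras.* paths are an assumption based on the
# Keras 2-style arguments used below (lr, decay); the original file may import differently.
# get_data, LOG_FILE and MODEL_WEIGHTS are not shown in this excerpt.
from keras.callbacks import CSVLogger, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import SGD


def train(self, model, saveto_path=''):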
        x_train, y_train = get_data(self.train_data_path, "train", "frame", self.feature_type)
        print('%d training frame level samples.' % len(x_train))
        x_valid, y_valid = get_data(self.valid_data_path, "valid", "frame", self.feature_type)
        print('%d validation frame level samples.' % len(x_valid))

        sgd = SGD(lr=0.01,
                  decay=1e-6,
                  momentum=0.9,
                  nesterov=True)
        model.compile(loss='categorical_crossentropy',
                      optimizer=sgd,
                      metrics=['accuracy'])

        callbacks = [
            CSVLogger(LOG_FILE),
            ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=0.0001),
        ]

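        # Checkpoint after every epoch, but only when an output path was requested.
        # Note that checkpoints are written to MODEL_WEIGHTS, not to saveto_path.
        if saveto_path:
            callbacks.append(ModelCheckpoint(filepath=MODEL_WEIGHTS, verbose=1))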

        model.fit(x_train,
                  y_train,
                  epochs=5,
                  callbacks=callbacks,
                  validation_data=(x_valid, y_valid))

        # Save the weights on completion.
        if saveto_path:
            model.save_weights(saveto_path)
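
The ReduceLROnPlateau callback above is configured with factor=0.1, patience=2 and min_lr=0.0001, starting from the SGD learning rate of 0.01. The following is a minimal pure-Python sketch of how such a schedule might behave. It is a simplified illustration, not the Keras implementation: the function name `simulate_reduce_on_plateau`, the `min_delta` default, and the sample loss values are all assumptions made for the example, and the cooldown logic is left out.

```python
def simulate_reduce_on_plateau(val_losses, lr=0.01, factor=0.1, patience=2,
                               min_lr=0.0001, min_delta=1e-4):
    """Return the learning rate in effect after each epoch.

    Hypothetical helper: a simplified model of reduce-on-plateau scheduling
    (no cooldown), not the library's actual code.
    """
    best = float("inf")   # lowest validation loss seen so far
    wait = 0              # epochs since the last meaningful improvement
    history = []
    for loss in val_losses:
        if loss < best - min_delta:
            # Meaningful improvement: remember it and reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience and lr > min_lr:
                # Plateau detected: shrink the rate, but never below min_lr.
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history


# Made-up validation losses that stall after the second epoch.
losses = [1.0, 0.9, 0.95, 0.93, 0.92, 0.91, 0.94, 0.96]
print(simulate_reduce_on_plateau(losses))
```

With these sample losses, the rate stays at 0.01 for three epochs, drops to roughly 0.001 at epoch 4 after two epochs without improvement, and reaches the 0.0001 floor at epoch 6. Since train() runs only 5 epochs, a schedule like this leaves room for at most a couple of reductions in practice.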