framework.py source code

python

Project: lang2program    Author: kelvinguu
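```python
from os.path import join

# Keras 1.x API: fit_generator() takes samples_per_epoch / nb_epoch / nb_val_samples.
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard


def train(self, train_batches, valid_batches, samples_per_epoch, nb_epoch, nb_val_samples, extra_callbacks=None):
    """Train the model.

    Automatically adds the following Keras callbacks:
        - ModelCheckpoint
        - EarlyStopping
        - TensorBoard

    Args:
        train_batches (Iterable[Batch]): an iterable of training Batches
        valid_batches (Iterable[Batch]): an iterable of validation Batches
        samples_per_epoch (int): number of training samples to process per epoch
        nb_epoch (int): max number of epochs to train for
        nb_val_samples (int): number of samples for validation
        extra_callbacks (list): a list of additional Keras callbacks to run
    """
    # Save weights after every epoch; the filename records the epoch and the validation loss.
    checkpoint_path = join(self.checkpoint_dir, 'weights.{epoch:02d}-{val_loss:.2f}.hdf5')
    checkpointer = ModelCheckpoint(checkpoint_path, verbose=1, save_best_only=False)
    # Stop training once val_loss has stopped improving (patience of 2 epochs).
    early_stopper = EarlyStopping(monitor='val_loss', patience=2, verbose=1)
    # Log metrics for TensorBoard without writing the model graph.
    tboard = TensorBoard(self.tensorboard_dir, write_graph=False)

    callbacks = [checkpointer, early_stopper, tboard]
    if extra_callbacks:
        callbacks.extend(extra_callbacks)

    # _vectorized_batches is defined elsewhere in this class (not shown here);
    # it turns Batch objects into model inputs for Keras.
    train_gen = self._vectorized_batches(train_batches)
    valid_gen = self._vectorized_batches(valid_batches)

    self.keras_model.fit_generator(train_gen,
                                   samples_per_epoch=samples_per_epoch,
                                   nb_epoch=nb_epoch,
                                   callbacks=callbacks,
                                   validation_data=valid_gen,
                                   nb_val_samples=nb_val_samples)
```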
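To make the effect of the two default callbacks concrete, here is a minimal pure-Python sketch that replays a made-up sequence of per-epoch validation losses. It reuses the checkpoint filename template from `train()` and mimics patience-based early stopping with `patience=2`. The function `simulate_training` and the loss values are hypothetical; this is not Keras's implementation, it numbers epochs from 1 (an assumption), and Keras releases have differed slightly in exactly which epoch `EarlyStopping` fires on.

```python
import math
from typing import List, Tuple

# Same filename template that train() passes to ModelCheckpoint.
CHECKPOINT_TEMPLATE = 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'


def simulate_training(val_losses: List[float], patience: int = 2) -> Tuple[List[str], int]:
    """Hypothetical helper: replay per-epoch validation losses.

    Returns the checkpoint filenames that would be written (one per epoch,
    since save_best_only=False) and the 1-based epoch at which training ends.
    Simplified rule: stop after `patience` consecutive epochs without a new
    lowest val_loss.
    """
    filenames = []
    best = math.inf
    wait = 0
    for epoch, val_loss in enumerate(val_losses, start=1):
        # save_best_only=False: a checkpoint is written after every epoch.
        filenames.append(CHECKPOINT_TEMPLATE.format(epoch=epoch, val_loss=val_loss))
        if val_loss < best:
            best = val_loss  # improvement: reset the patience counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return filenames, epoch  # early stop
    return filenames, len(val_losses)  # used the full epoch budget


files, stopped = simulate_training([0.91, 0.74, 0.76, 0.80, 0.70])
print(stopped)    # 4
print(files[-1])  # weights.04-0.80.hdf5
```

In this run the loss fails to improve in epochs 3 and 4, so training stops before the lower loss in epoch 5 is ever seen. Because `save_best_only=False`, one file is kept per epoch, so picking the best weights afterwards would mean reading the `val_loss` embedded in the filenames.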