interface.py source code

python

Project: sk-torch  Author: mattHawthorn  File: interface.py
def fit_zipped(self, dataset: Iterable[Tuple[T1, T2]], test_dataset: Opt[Iterable[Tuple[T1, T2]]] = None,
               batch_size: Opt[int] = None,
               max_epochs: int = 1, min_epochs: int = 1, criterion_window: int = 5,
               max_training_time: Opt[float] = None,
               batch_report_interval: Opt[int] = None, epoch_report_interval: Opt[int] = None):
        """For fitting to an iterable sequence of pairs, such as may arise in very large streaming datasets from sources
        that don't fit the random access and known-length requirements of a torch.data.Dataset (e.g. a sequence of
        sentences split from a set of text files as might arise in NLP applications.
        Like TorchModel.fit(), this estimates input normalization before the weight update, and weight initialization of
        the torch_module is up to you. Returns per-epoch losses and validation losses (if any).
        This method handles packaging X and y into a batch iterator of the kind that torch modules expect."""
        batch_size = batch_size or self.default_batch_size
        if self.should_normalize:
            sample, dataset = peek(dataset, self.norm_n_samples)
            sample = [t[0] for t in sample]
            if self.encode_input:
                sample = [self.encode_input(x) for x in sample]
            sample = stack(sample)
            self.estimate_normalization(sample)

        return self.update_zipped(dataset=dataset, test_dataset=test_dataset, batch_size=batch_size,
                                  max_epochs=max_epochs, min_epochs=min_epochs,
                                  criterion_window=criterion_window,
                                  max_training_time=max_training_time,
                                  batch_report_interval=batch_report_interval, epoch_report_interval=epoch_report_interval)
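
The normalization branch relies on a `peek` helper that samples the first few items of a one-pass iterable without losing them, then returns an equivalent iterator with the sampled items re-attached at the front. That helper's implementation isn't shown in this snippet; a minimal sketch of the behavior the code above assumes, built on `itertools.chain`/`islice`, might look like:

```python
from itertools import chain, islice
from typing import Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")

def peek(iterable: Iterable[T], n: int) -> Tuple[List[T], Iterator[T]]:
    """Return the first n items plus an iterator that still yields ALL items,
    including the n that were peeked (they are chained back onto the front)."""
    it = iter(iterable)
    head = list(islice(it, n))          # consume up to n items for sampling
    return head, chain(head, it)        # re-attach them so nothing is lost

# Streaming (x, y) pairs, e.g. sentences with labels; a generator is one-pass.
pairs = ((f"sentence {i}", i % 2) for i in range(10))
sample, pairs = peek(pairs, 3)
```

After the call, `sample` holds the first three pairs for estimating normalization statistics, while iterating `pairs` still produces all ten items, which is exactly what `fit_zipped` needs before handing the dataset on to `update_zipped`.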
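
The docstring says the method "handles packaging X and y into a batch iterator of the kind that torch modules expect." The actual packaging happens inside `update_zipped`, which isn't shown here; purely as an illustration of the idea, a hedged sketch of chunking a zipped stream into (inputs, targets) minibatches could be:

```python
from itertools import islice
from typing import Iterable, Iterator, List, Tuple, TypeVar

T1 = TypeVar("T1")
T2 = TypeVar("T2")

def batched(pairs: Iterable[Tuple[T1, T2]],
            batch_size: int) -> Iterator[Tuple[List[T1], List[T2]]]:
    """Yield (xs, ys) minibatches from a stream of (x, y) pairs.
    The final batch may be smaller than batch_size."""
    it = iter(pairs)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        xs, ys = zip(*batch)            # unzip pairs into parallel sequences
        yield list(xs), list(ys)

batches = list(batched([(i, i * i) for i in range(5)], batch_size=2))
```

In the real library each `xs`/`ys` list would presumably be encoded and stacked into tensors (as the `encode_input`/`stack` calls above suggest); this sketch only shows the pair-to-batch restructuring step.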