def fit_zipped(self, dataset: Iterable[Tuple[T1, T2]],
               test_dataset: Opt[Iterable[Tuple[T1, T2]]] = None,
               batch_size: Opt[int] = None,
               max_epochs: int = 1, min_epochs: int = 1, criterion_window: int = 5,
               max_training_time: Opt[float] = None,
               batch_report_interval: Opt[int] = None, epoch_report_interval: Opt[int] = None):
    """For fitting to an iterable sequence of pairs, such as may arise in very large streaming
    datasets from sources that don't fit the random-access and known-length requirements of a
    torch.utils.data.Dataset (e.g. a sequence of sentences split from a set of text files, as
    might arise in NLP applications).
    Like TorchModel.fit(), this estimates input normalization before the weight updates, and
    weight initialization of the torch_module is up to you. Returns per-epoch losses and
    validation losses (if any).
    This method handles packaging X and y into a batch iterator of the kind that torch modules
    expect."""
    batch_size = batch_size or self.default_batch_size
    if self.should_normalize:
        # Pull the first norm_n_samples pairs off the stream without losing them,
        # then estimate input normalization from the inputs alone.
        sample, dataset = peek(dataset, self.norm_n_samples)
        sample = [t[0] for t in sample]
        if self.encode_input:
            sample = [self.encode_input(x) for x in sample]
        sample = stack(sample)
        self.estimate_normalization(sample)
    return self.update_zipped(dataset=dataset, test_dataset=test_dataset, batch_size=batch_size,
                              max_epochs=max_epochs, min_epochs=min_epochs,
                              criterion_window=criterion_window,
                              max_training_time=max_training_time,
                              batch_report_interval=batch_report_interval,
                              epoch_report_interval=epoch_report_interval)
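
# A minimal sketch of the `peek` helper assumed by the call site above: it must
# return the first n items for normalization estimation *and* an iterator that
# replays them, since a streaming dataset cannot be rewound. This is an assumed
# implementation inferred from usage, not the project's own code.
import itertools
from typing import Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")

def peek(iterable: Iterable[T], n: int) -> Tuple[List[T], Iterator[T]]:
    it = iter(iterable)
    head = list(itertools.islice(it, n))    # consume up to n items from the stream
    return head, itertools.chain(head, it)  # replay the head, then yield the rest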
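
# Usage sketch: streaming (input, target) pairs from tab-separated text files
# into fit_zipped without materializing the corpus. `model` is assumed to be an
# already-constructed TorchModel instance; the file names, label format, and
# `pair_stream` helper are hypothetical.
def pair_stream(paths):
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                sentence, label = line.rstrip("\n").split("\t")
                yield sentence, int(label)

# Per the docstring, the return value holds per-epoch losses and validation
# losses (if a test_dataset was given).
losses = model.fit_zipped(pair_stream(["train.tsv"]),
                          test_dataset=pair_stream(["dev.tsv"]),
                          max_epochs=5)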