def fit(self, X: Iterable[T1], y: Iterable[T2],
        X_test: Opt[Iterable[T1]]=None, y_test: Opt[Iterable[T2]]=None,
        batch_size: Opt[int]=None, shuffle: bool=False,
        max_epochs: int=1, min_epochs: int=1, criterion_window: int=5,
        max_training_time: Opt[float]=None,
        batch_report_interval: Opt[int]=None, epoch_report_interval: Opt[int]=None):
"""This method fits the *entire* pipeline, including input normalization. Initialization of weight/bias
parameters in the torch_module is up to you; there is no obvious canonical way to do it here.
Returns per-epoch losses and validation losses (if any)."""
    batch_size = batch_size or self.default_batch_size
    if self.should_normalize:
        # Peek at the first norm_n_samples inputs (without consuming the input stream)
        # to estimate normalization statistics before training starts.
        sample, X = peek(X, self.norm_n_samples)
        if self.encode_input:
            sample = [self.encode_input(x) for x in sample]
        sample = stack(sample)
        self.estimate_normalization(sample)
    # Delegate the actual training loop to update(), forwarding all options unchanged.
    return self.update(X=X, y=y, X_test=X_test, y_test=y_test, batch_size=batch_size, shuffle=shuffle,
                       max_epochs=max_epochs, min_epochs=min_epochs,
                       criterion_window=criterion_window,
                       max_training_time=max_training_time,
                       batch_report_interval=batch_report_interval,
                       epoch_report_interval=epoch_report_interval)
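

# A minimal sketch of the `peek` pattern the normalization step relies on. This is an
# illustration, not the library's implementation: it assumes peek(X, n) returns the first
# n items of X together with an iterator that still yields the *entire* stream, so that
# estimating normalization statistics does not consume any training data. The helper is
# named peek_sketch to make clear it is a stand-in.
from itertools import chain, islice
from typing import Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")

def peek_sketch(items: Iterable[T], n: int) -> Tuple[List[T], Iterator[T]]:
    """Return the first `n` items and an iterator over the full, unconsumed stream."""
    it = iter(items)
    head = list(islice(it, n))          # materialize only the sample
    return head, chain(head, it)        # re-attach the sample in front of the rest

# Usage: peek at 3 items, then iterate the full stream as usual.
head, stream = peek_sketch(range(10), 3)
assert head == [0, 1, 2]
assert list(stream) == list(range(10))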