interface.py source code

Project: sk-torch  Author: mattHawthorn
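```python
# Excerpt: fit() is a method of an sk-torch model class in interface.py.
# Names defined elsewhere in that file and not shown here: Iterable and Opt
# (presumably typing.Iterable and typing.Optional), the type variables T1/T2,
# and the helpers peek and stack.
def fit(self, X: Iterable[T1], y: Iterable[T2],
        X_test: Opt[Iterable[T1]]=None, y_test: Opt[Iterable[T2]]=None,
        batch_size: Opt[int]=None, shuffle: bool=False,
        max_epochs: int=1, min_epochs: int=1, criterion_window: int=5,
        max_training_time: Opt[float]=None,
        batch_report_interval: Opt[int]=None, epoch_report_interval: Opt[int]=None):
    """This method fits the *entire* pipeline, including input normalization. Initialization of weight/bias
    parameters in the torch_module is up to you; there is no obvious canonical way to do it here.
    Returns per-epoch losses and validation losses (if any)."""
    # Fall back to the model's default batch size when none is given.
    batch_size = batch_size or self.default_batch_size
    if self.should_normalize:
        # Take the first norm_n_samples inputs. X is rebound to the iterable
        # that peek returns, presumably one that still yields every element.
        sample, X = peek(X, self.norm_n_samples)
        if self.encode_input:
            # Encode raw inputs before stacking, if an encoder is configured.
            sample = [self.encode_input(x) for x in sample]
        sample = stack(sample)
        # Fit the normalization statistics on the stacked sample.
        self.estimate_normalization(sample)

    # Delegate the actual training loop to update().
    return self.update(X=X, y=y, X_test=X_test, y_test=y_test, batch_size=batch_size, shuffle=shuffle,
                       max_epochs=max_epochs, min_epochs=min_epochs,
                       criterion_window=criterion_window,
                       max_training_time=max_training_time,
                       batch_report_interval=batch_report_interval, epoch_report_interval=epoch_report_interval)
```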
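The normalization step in `fit` depends on `peek` to inspect the first `norm_n_samples` inputs without losing them, which matters when `X` is a one-shot generator. sk-torch's own `peek`, `stack`, and `estimate_normalization` are not shown in this excerpt, so what follows is only a minimal stdlib sketch under stated assumptions. `peek`, `estimate_mean_std`, and `normalize` here are hypothetical helpers. Per-feature mean/standard-deviation scaling is also an assumed normalization scheme, because the excerpt does not say which statistics sk-torch estimates.

```python
from itertools import chain, islice
from statistics import mean, pstdev
from typing import Iterable, Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def peek(items: Iterable[T], n: int) -> Tuple[List[T], Iterator[T]]:
    """Return the first n items plus an iterator that still yields ALL items.

    Hypothetical stand-in for the peek helper used by fit(); not sk-torch's code.
    Works for one-shot generators as well as lists.
    """
    it = iter(items)
    head = list(islice(it, n))
    # Re-attach the consumed head so downstream training sees every element.
    return head, chain(head, it)


def estimate_mean_std(sample: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-feature mean and population standard deviation over a sample (assumed scheme)."""
    columns = list(zip(*sample))
    means = [mean(col) for col in columns]
    # Replace a zero std with 1.0 so that dividing by it later is safe.
    stds = [pstdev(col) or 1.0 for col in columns]
    return means, stds


def normalize(x: Sequence[float], means: Sequence[float], stds: Sequence[float]) -> List[float]:
    """Standardize one feature vector with the estimated statistics."""
    return [(v - m) / s for v, m, s in zip(x, means, stds)]


# Usage: a generator can be iterated only once, which is why peek is needed.
data = ([float(i), 2.0 * i] for i in range(10))
sample, data = peek(data, 4)
means, stds = estimate_mean_std(sample)
normalized = [normalize(x, means, stds) for x in data]
print(len(normalized), means)  # prints: 10 [1.5, 3.0]
```

Rebinding the iterator (`sample, data = peek(data, 4)`) mirrors `sample, X = peek(X, self.norm_n_samples)` in `fit`. The original generator has already been partly consumed at that point, so only the returned chained iterator is safe to pass on to training.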