models.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:keras-gp 作者: alshedivat 项目源码 文件源码
def predict(self, X, X_tr=None, Y_tr=None,
                batch_size=32, return_var=False, verbose=0):
        """Generate output predictions for the input samples batch by batch.

        If both `X_tr` and `Y_tr` are given, the training data held by every
        output GP layer's backend is replaced first (and the grid is updated
        for layers whose `update_grid` flag is set). The inputs `X` are then
        passed through `self.transform` and each resulting representation is
        handed to the matching GP backend for prediction.

        Arguments:
        ----------
            X : np.ndarray or list of np.ndarrays
                Inputs to predict on.
            X_tr : np.ndarray or list of np.ndarrays, optional (default: None)
                Training inputs used to refresh the GP data before predicting.
            Y_tr : np.ndarray or list of np.ndarrays, optional (default: None)
                Training targets matching `X_tr`. The GP data is refreshed
                only when BOTH `X_tr` and `Y_tr` are provided; passing just
                one of them is silently ignored.
            batch_size : uint (default: 32)
                Batch size forwarded to `self.transform` (and to the user-data
                standardization when training data is given).
            return_var : bool (default: False)
                Whether predictive variance is returned.
            verbose : uint (default: 0)
                Verbosity mode, 0 or 1. Accepted for interface compatibility;
                it is not used by this method.

        Returns:
        --------
            preds : a list, or a list of lists
                If `return_var` is False: a list with one backend prediction
                per output GP layer. If `return_var` is True: the per-layer
                results regrouped by position, i.e. `preds[k][i]` is the k-th
                item returned by the backend of the i-th GP layer
                (presumably k=0 is the mean and k=1 the variance -- confirm
                against the backend's `predict`).
        """
        # Update GP data if provided (and grid if necessary).
        if X_tr is not None and Y_tr is not None:
            X_tr, Y_tr, _ = self._standardize_user_data(
                X_tr, Y_tr,
                sample_weight=None,
                class_weight=None,
                check_batch_axis=False,
                batch_size=batch_size)
            H_tr = self.transform(X_tr, batch_size=batch_size)
            for gp, h, y in zip(self.output_gp_layers, H_tr, Y_tr):
                gp.backend.update_data('tr', h, y)
                if gp.update_grid:
                    gp.backend.update_grid('tr')

        # Validate user data.
        X = _standardize_input_data(
            X, self.input_names, self.internal_input_shapes,
            check_batch_axis=False)

        H = self.transform(X, batch_size=batch_size)

        # One prediction per output GP layer, in layer order.
        preds = [gp.backend.predict(h, return_var=return_var)
                 for gp, h in zip(self.output_gp_layers, H)]

        if return_var:
            # Regroup per-layer tuples into per-quantity lists. A list
            # comprehension is used instead of `map(...)` so the result is a
            # concrete list of lists on Python 3 as well (where `map` would
            # return a lazy, single-use iterator).
            preds = [list(group) for group in zip(*preds)]

        return preds


# Apply tweaks
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号