lstm.py source code

python

Project: lain  Author: llllllllll
    def _sample_weights(self, aim_error, accuracy_error):
        """Sample weights based on the error.

        Parameters
        ----------
        aim_error : np.ndarray
            The aim errors for each sample.
        accuracy_error : np.ndarray
            The accuracy errors for each sample.

        Returns
        -------
        weights : dict[str, np.ndarray]
            The per-sample weights, keyed by 'aim_error' and 'accuracy_error'.

        Notes
        -----
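        Each error array is converted to z-scores (standard deviations from
        the mean), clipped to the range [1, 4], and divided by 4, so every
        weight lies in [0.25, 1]. Samples at or below one standard deviation
        above the mean all receive the minimum weight of 0.25. If an error
        array has zero standard deviation, the z-score is undefined.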
        """
        aim_zscore = (aim_error - aim_error.mean()) / aim_error.std()
        aim_weight = np.clip(aim_zscore, 1, 4) / 4

        accuracy_zscore = (
            accuracy_error - accuracy_error.mean()
        ) / accuracy_error.std()
        accuracy_weight = np.clip(accuracy_zscore, 1, 4) / 4

        return {
            'aim_error': aim_weight,
            'accuracy_error': accuracy_weight,
        }
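
The weighting rule used by the method above can be tried in isolation. The following is a minimal sketch, not code from the lain project: the helper name `zscore_weight`, its `low`/`high` parameters, and the zero-spread fallback are assumptions introduced here for illustration.

```python
import numpy as np


def zscore_weight(error, low=1.0, high=4.0):
    """Clip the z-scores of `error` to [low, high] and divide by `high`.

    Hypothetical standalone helper mirroring the clip-and-scale step
    applied to each error array in `_sample_weights`.
    """
    error = np.asarray(error, dtype=float)
    std = error.std()
    if std == 0:
        # Assumption (not in the original): when there is no spread,
        # give every sample the minimum weight instead of dividing by 0.
        return np.full_like(error, low / high)
    zscore = (error - error.mean()) / std
    return np.clip(zscore, low, high) / high


# Four samples with no error and one outlier: mean is 2, std is 4,
# so the z-scores are [-0.5, -0.5, -0.5, -0.5, 2.0].
errors = np.array([0.0, 0.0, 0.0, 0.0, 10.0])
weights = zscore_weight(errors)
print(weights)  # [0.25 0.25 0.25 0.25 0.5 ]
```

Because the lower clip bound is 1, every sample that is not at least one standard deviation above the mean collapses to the same floor weight of 0.25; only clear outliers are weighted more heavily, up to a cap of 1.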