utils.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目: mobula 作者: wkcn 项目源码 文件源码
def check_dtype(dtype, func_str, a, n):
    """Return the dtype to use for the result of aggregating ``a`` with ``func_str``.

    Args:
        dtype: dtype explicitly requested by the caller, or ``None`` to let
            this function infer one from ``func_str`` and the dtype of ``a``.
        func_str: name of the aggregation function (e.g. ``'sum'``,
            ``'len'``, ``'all'``).
        a: the values being aggregated; an array or a scalar.
        n: multiplier used only when inferring a dtype for ``'sum'`` on
            integer or bool input (the guessed maximum is ``max(dtype) * n``,
            or ``n`` for bool). Presumably the largest number of values
            summed into one output element -- confirm against callers.

    Returns:
        numpy.dtype: the dtype the aggregation result should use.

    Raises:
        ValueError: ``a`` is a scalar (or 0-d array) and ``func_str`` is not
            ``'sum'``, ``'prod'`` or ``'len'``.
        TypeError: a bool ``dtype`` was requested for a function whose name
            contains neither ``'all'`` nor ``'any'``, or a non-integer
            ``dtype`` was requested for ``'len'`` / ``'nanlen'``.
    """
    if np.isscalar(a) or not a.shape:
        if func_str not in ("sum", "prod", "len"):
            raise ValueError("scalar inputs are supported only for 'sum', "
                             "'prod' and 'len'")
        # NumPy scalars and 0-d arrays carry their own dtype; type() of a
        # 0-d array is just ndarray, so only fall back to type(a) for plain
        # Python scalars (int, float, bool, ...).
        a_dtype = getattr(a, "dtype", None)
        if a_dtype is None:
            a_dtype = np.dtype(type(a))
    else:
        a_dtype = a.dtype

    if dtype is not None:
        # dtype set by the user.
        # Careful here: np.bool != np.bool_ !
        if np.issubdtype(dtype, np.bool_) and \
                not ('all' in func_str or 'any' in func_str):
            raise TypeError("function %s requires a more complex datatype "
                            "than bool" % func_str)
        if not np.issubdtype(dtype, np.integer) and func_str in ('len', 'nanlen'):
            raise TypeError("function %s requires an integer datatype" % func_str)
        # TODO: Maybe have some more checks here
        return np.dtype(dtype)

    # Functions whose result dtype is fixed regardless of the input dtype.
    if func_str in _forced_types:
        return np.dtype(_forced_types[func_str])

    # Functions that need a floating result: keep float input as-is,
    # otherwise use float64.
    if func_str in _forced_float_types:
        if np.issubdtype(a_dtype, np.floating):
            return a_dtype
        return np.dtype(np.float64)

    if func_str == 'sum':
        # Try to guess the minimally required int size.
        if np.issubdtype(a_dtype, np.int64):
            # It's not getting bigger anymore.
            # TODO: strictly speaking it might need float
            return np.dtype(np.int64)
        if np.issubdtype(a_dtype, np.integer):
            maxval = np.iinfo(a_dtype).max * n
            return minimum_dtype(maxval, a_dtype)
        if np.issubdtype(a_dtype, np.bool_):
            return minimum_dtype(n, a_dtype)
        # floating, inexact, whatever
        return a_dtype

    if func_str in _forced_same_type:
        return a_dtype

    # Any other function: widen integer input to int64. a_dtype is a dtype
    # *instance*, so this must use issubdtype -- isinstance(a_dtype,
    # np.integer) would always be False.
    # NOTE(review): uint64 input also maps to int64 here, which cannot hold
    # its full range -- confirm this is acceptable for the affected functions.
    if np.issubdtype(a_dtype, np.integer):
        return np.dtype(np.int64)
    return a_dtype
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号