test_basic.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def get_numeric_types(with_int=True, with_float=True, with_complex=False,
                      only_theano_types=True):
    """
    Collect numpy numeric dtypes, filtered by kind.

    Candidate scalar classes come from `get_numeric_subclasses()`; each one
    is converted to a `numpy.dtype` and kept unless one of the filters below
    rejects it.

    :param with_int: If false, drop every class that falls under
    `numpy.integer`.

    :param with_float: If false, drop every class that falls under
    `numpy.floating`.

    :param with_complex: If false, drop every class that falls under
    `numpy.complexfloating`.

    :param only_theano_types: If true, drop every dtype that is not found
    among the `dtype` attributes of `theano.scalar.all_types`.

    :returns: A list of dtype objects, in a deterministic order. Several
    entries may print identically; their `num` attribute tells them apart.

    Even with `only_theano_types` set, this is not the same as returning the
    types of the `scalar` module directly: walking the numpy classes lets
    tests cover more distinct dtype objects, and could later be used to spot
    data types newly added to numpy.
    """
    supported = None
    if only_theano_types:
        supported = [scalar_type.dtype
                     for scalar_type in theano.scalar.all_types]

    # Abstract bases whose members must be rejected, in the order the
    # filters are checked: complex first, then integer, then floating.
    rejected_bases = []
    if not with_complex:
        rejected_bases.append(numpy.complexfloating)
    if not with_int:
        rejected_bases.append(numpy.integer)
    if not with_float:
        rejected_bases.append(numpy.floating)

    def falls_under(scalar_cls, base_cls):
        # True when scalars built from `scalar_cls` belong to the hierarchy
        # rooted at `base_cls`.
        if scalar_cls is base_cls or issubclass(scalar_cls, base_cls):
            return True
        # Class relationships alone are not enough: an abstract class such
        # as `numpy.number` used as a dtype yields a float64 scalar although
        # it does not derive from `numpy.floating`. So build an actual
        # scalar and check what it is an instance of.
        sample = numpy.array([0], dtype=scalar_cls)[0]
        return isinstance(sample, base_cls)

    entries = []
    for scalar_cls in get_numeric_subclasses():
        dt = numpy.dtype(scalar_cls)
        if any(falls_under(scalar_cls, base) for base in rejected_bases):
            continue
        if only_theano_types and dt not in supported:
            continue
        entries.append([str(dt), dt, dt.num])

    # Order by the textual form of each whole [name, dtype, num] entry so the
    # result is deterministic, then keep only the dtype objects.
    entries.sort(key=str)
    return [entry[1] for entry in entries]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号