def get_flags(*types):
    """Return a dict of capability flags required by the given element types.

    Each positional argument describes one element type and may be:

    * a dtype name string (anything ``numpy.dtype`` accepts, e.g. ``'float32'``),
    * a ``numpy.dtype`` instance,
    * a numpy scalar type such as ``numpy.float32``,
    * a ``Type`` (its ``.dtype`` attribute is used), or
    * a ``Variable`` (its ``.type.dtype`` attribute is used).

    The returned dict always contains ``cluda=True``. It additionally contains
    (each set to ``True`` only when it applies):

    * ``have_double``  -- some dtype is float64 or complex128,
    * ``have_small``   -- some dtype is narrower than 4 bytes,
    * ``have_complex`` -- some dtype is complex,
    * ``have_half``    -- some dtype is float16.

    NOTE(review): these keys presumably feed GPU kernel compilation
    (CLUDA-style kernels); confirm against the callers.

    Raises:
        TypeError: if an argument is none of the accepted kinds.
    """
    def get_dtype(t):
        # Strings, numpy.dtype instances and numpy scalar classes all go
        # through numpy.dtype(); they are checked first so that the
        # project-level Type/Variable checks are only reached for other
        # objects.
        if (isinstance(t, (string_types, numpy.dtype)) or
                (isinstance(t, type) and issubclass(t, numpy.generic))):
            return numpy.dtype(t)
        elif isinstance(t, Type):
            return t.dtype
        elif isinstance(t, Variable):
            return t.type.dtype
        else:
            raise TypeError("can't get a dtype from %s" % (type(t),))

    dtypes = [get_dtype(t) for t in types]
    flags = dict(cluda=True)
    # complex128 is a pair of float64 values, so it needs double-precision
    # support just as much as a plain float64 does.
    if any(d == numpy.float64 or d == numpy.complex128 for d in dtypes):
        flags['have_double'] = True
    if any(d.itemsize < 4 for d in dtypes):
        flags['have_small'] = True
    if any(d.kind == 'c' for d in dtypes):
        flags['have_complex'] = True
    if any(d == numpy.float16 for d in dtypes):
        flags['have_half'] = True
    return flags
评论列表
文章目录