def numpy2bifrost(dtype):
    """Return the Bifrost ``BF_DTYPE_*`` code corresponding to ``dtype``.

    Args:
        dtype: A NumPy scalar type or dtype (e.g. ``np.int16`` or
            ``np.dtype('complex64')``), or one of the complex-integer /
            half-complex dtypes ``ci8``, ``ci16``, ``ci32``, ``cf16``
            (module-level names defined outside this function).

    Returns:
        The matching ``_bf.BF_DTYPE_*`` constant.

    Raises:
        ValueError: If ``dtype`` has no supported Bifrost equivalent
            (including ``float128``/``complex256`` on platforms where NumPy
            does not provide those types).
    """
    # np.float128 / np.complex256 are platform-dependent aliases that only
    # exist where the C long double is extended precision. Accessing them
    # unconditionally raises AttributeError elsewhere, which would break every
    # lookup that reaches these comparisons (ci*/cf16 and the complex types
    # included), so resolve them defensively and skip them when absent.
    float128 = getattr(np, 'float128', None)
    complex256 = getattr(np, 'complex256', None)

    # Signed integers
    if dtype == np.int8:
        return _bf.BF_DTYPE_I8
    if dtype == np.int16:
        return _bf.BF_DTYPE_I16
    if dtype == np.int32:
        return _bf.BF_DTYPE_I32
    # Unsigned integers
    if dtype == np.uint8:
        return _bf.BF_DTYPE_U8
    if dtype == np.uint16:
        return _bf.BF_DTYPE_U16
    if dtype == np.uint32:
        return _bf.BF_DTYPE_U32
    # Real floating point
    if dtype == np.float16:
        return _bf.BF_DTYPE_F16
    if dtype == np.float32:
        return _bf.BF_DTYPE_F32
    if dtype == np.float64:
        return _bf.BF_DTYPE_F64
    if float128 is not None and dtype == float128:
        return _bf.BF_DTYPE_F128
    # Complex types defined outside NumPy (see the module-level definitions)
    if dtype == ci8:
        return _bf.BF_DTYPE_CI8
    if dtype == ci16:
        return _bf.BF_DTYPE_CI16
    if dtype == ci32:
        return _bf.BF_DTYPE_CI32
    if dtype == cf16:
        return _bf.BF_DTYPE_CF16
    # Native NumPy complex floating point
    if dtype == np.complex64:
        return _bf.BF_DTYPE_CF32
    if dtype == np.complex128:
        return _bf.BF_DTYPE_CF64
    if complex256 is not None and dtype == complex256:
        return _bf.BF_DTYPE_CF128
    raise ValueError("Unsupported dtype: " + str(dtype))