def create_typed_numpy_ndarray(
        dims: int, data_type: t.ClassVar, required_shape: t.Optional[t.Sequence[int]] = None):
    """Create a statically typed version of numpy.ndarray.

    The returned callable has the calling convention of ``numpy.ndarray``:
    ``shape`` is the first positional argument or the ``shape`` keyword, and
    ``dtype`` is the second positional argument or the ``dtype`` keyword.
    Before constructing the array, it checks the request against the
    constraints declared here. It always builds the array with ``data_type``.

    :param dims: required number of dimensions; a bare integer shape counts
        as one dimension.
    :param data_type: dtype every created array uses. An explicitly supplied
        dtype must be this same object or describe the same numpy dtype.
    :param required_shape: optional per-axis sizes; an ``Ellipsis`` entry
        leaves that axis unconstrained. Axes are paired with ``zip``, so only
        as many axes as the shorter sequence has are compared.
    :returns: a factory producing ``numpy.ndarray`` instances that satisfy
        the declared constraints.
    """
    def typed_ndarray(*args, **kwargs):
        """Create an instance of numpy.ndarray which must conform to declared type constraints.

        :raises KeyError: if no shape is given, positionally or as ``shape=``.
        :raises ValueError: if the shape conflicts with ``dims`` or ``required_shape``.
        :raises TypeError: if an explicitly given dtype conflicts with ``data_type``.
        """
        # Use a list rather than the incoming tuple so a positional dtype can be replaced below.
        args = list(args)
        shape = args[0] if args else kwargs['shape']
        if shape is not None:
            # numpy accepts a bare integer (including numpy integer scalars) as a 1-D shape;
            # normalise to a tuple so the dims and required_shape checks treat both forms alike.
            if isinstance(shape, (int, np.integer)):
                axes = (shape,)
            else:
                axes = tuple(shape)
            if len(axes) != dims:
                raise ValueError(
                    'actual ndarray shape {} conflicts with its declared dimensionality of {}'
                    .format(shape, dims))
            if required_shape is not None and any(
                    req_dim is not Ellipsis and dim != req_dim
                    for dim, req_dim in zip(axes, required_shape)):
                raise ValueError('actual ndarray shape {} conflicts with its required shape of {}'
                                 .format(shape, required_shape))
        positional_dtype = len(args) > 1
        dtype = args[1] if positional_dtype else kwargs.get('dtype')
        # Identity handles the common case cheaply. Otherwise compare as numpy dtypes, so that
        # equivalent spellings (e.g. np.dtype('float64') vs np.float64) are not rejected.
        if dtype is not None and dtype is not data_type \
                and np.dtype(dtype) != np.dtype(data_type):
            raise TypeError('actual ndarray dtype {} conflicts with its declared dtype {}'
                            .format(dtype, data_type))
        if positional_dtype:
            args[1] = data_type
        else:
            kwargs['dtype'] = data_type
        return np.ndarray(*args, **kwargs)
    return typed_ndarray