def image_preprocess(obs, resize_width, resize_height, to_gray):
    """Apply basic preprocessing to an image observation.

    The image has its size-1 axes squeezed away, is optionally converted to
    grayscale, is optionally resized, and is finally made 3-D by adding a
    trailing channel axis when it is 2-D.

    Args:
        obs (numpy.ndarray): 2-D or 3-D uint8 image. Any size-1 axes are
            removed with ``np.squeeze`` before processing.
        resize_width (int): Target width in pixels. Pass None (or 0) to
            disable resizing; resizing happens only if both sizes are given.
        resize_height (int): Target height in pixels. Pass None (or 0) to
            disable resizing; resizing happens only if both sizes are given.
        to_gray (bool): If True, convert a multi-channel RGB image to
            grayscale. An image that is already single-channel (2-D after
            squeezing) is left unchanged.

    Returns:
        numpy.ndarray: 3-D image laid out as (height, width, channels). The
        dtype is not changed (e.g. a uint8 input yields a uint8 output).
    """
    processed_obs = np.squeeze(obs)

    # COLOR_RGB2GRAY rejects single-channel input, so only convert when a
    # channel axis survived the squeeze. (H, W) and (H, W, 1) inputs are both
    # 2-D at this point and are already grayscale.
    if to_gray and processed_obs.ndim == 3:
        processed_obs = cv2.cvtColor(processed_obs, cv2.COLOR_RGB2GRAY)

    # cv2.resize expects dsize as (width, height), not (height, width).
    if resize_height and resize_width:
        processed_obs = cv2.resize(processed_obs, (resize_width, resize_height))

    # Grayscale / single-channel images are 2-D here; add a trailing channel
    # axis so callers always receive a 3-D array.
    if processed_obs.ndim == 2:
        processed_obs = np.expand_dims(processed_obs, 2)
    return processed_obs
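# NOTE: the two labels below appear to be leftover text from the web page this
# snippet was copied from; they are not code.
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")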