def generate_batch(im):
    """
    Preprocess a single image and wrap it in an MXNet input batch.

    :param im: image as returned by cv2.imread, i.e. [height, width, channel] in BGR
    :return:
        data_batch: MXNet DataBatch holding the 'data' and 'im_info' arrays
        data_names: names in data_batch (the module-level DATA_NAMES)
        im_scale: float scale factor returned by resize()
    """
    # Resize within the configured SHORT_SIDE / LONG_SIDE limits; the stride
    # comes from config.IMAGE_STRIDE.
    im_array, im_scale = resize(im, SHORT_SIDE, LONG_SIDE, stride=config.IMAGE_STRIDE)
    # transform() must return a 4-D array: axes 2 and 3 are read below as the
    # image height and width (presumably NCHW -- confirm against transform()).
    im_array = transform(im_array, PIXEL_MEANS)
    # Single row describing the (resized) image: [height, width, scale].
    im_info = np.array([[im_array.shape[2], im_array.shape[3], im_scale]], dtype=np.float32)
    data = [mx.nd.array(im_array), mx.nd.array(im_info)]
    data_shapes = [('data', im_array.shape), ('im_info', im_info.shape)]
    # No labels are attached: label and provide_label are both None.
    data_batch = mx.io.DataBatch(data=data, label=None, provide_data=data_shapes, provide_label=None)
    return data_batch, DATA_NAMES, im_scale
评论列表
文章目录