def _stack_image_channels(im, value, dtype):
    """Stack ``im`` and ``value`` along a new axis 1 and cast to ``dtype``.

    Equivalent to concatenating ``np.expand_dims(im, 1)`` and
    ``np.expand_dims(value, 1)`` along axis 1; both inputs must have the
    same shape.
    """
    return np.stack((im, value), axis=1).astype(dtype)


def _one_hot_labels(labels, num_classes):
    """One-hot encode ``labels`` into a 2-D float array with ``num_classes`` columns.

    Each row is built with ``np.eye(1, num_classes, label)``, so a label
    outside ``[0, num_classes)`` yields an all-zero row rather than raising.
    """
    rows = [np.eye(1, num_classes, label)[0] for label in labels]
    # reshape keeps the result 2-D (shape (0, num_classes)) when there are no labels.
    return np.array(rows).reshape(-1, num_classes)


def process_map_data(path, return_full=False, num_classes=8):
    """Load a map dataset dumped with joblib and split it into train/test parts.

    The object at ``path`` must be dict-like with the keys ``'im'``,
    ``'value'``, ``'state'``, ``'label'`` and ``'sample_idx'``; ``'im'``
    must be an array (its ``shape[0]`` is the sample count).

    Note: ``joblib.load`` unpickles the file, so only load trusted data.

    Args:
        path: Location of the joblib dump.
        return_full: If true, skip the split and one-hot encoding and return
            the whole dataset.
        num_classes: Width of the one-hot label encoding (default 8).

    Returns:
        If ``return_full`` is true: ``(im_full, state, label, sample_idx)``,
        where ``im_full`` stacks ``'im'`` and ``'value'`` along a new axis 1
        as ``uint8``. The remaining items are returned exactly as loaded.

        Otherwise: ``((im_train, state_train, label_train),
        (im_test, state_test, label_test), sample_idx)``.
        - The first ``num - num // 5`` samples form the training split and
          the rest form the test split; the data is not shuffled.
        - Images are stacked as above but cast to ``float32``.
        - Labels are one-hot encoded into ``num_classes`` columns.
    """
    data = joblib.load(path)
    im_data = data['im']
    value_data = data['value']
    state_data = data['state']

    if return_full:
        im_full = _stack_image_channels(im_data, value_data, np.uint8)
        return im_full, state_data, data['label'], data['sample_idx']

    label_data = _one_hot_labels(data['label'], num_classes)

    num = im_data.shape[0]
    # Hold out the last num // 5 samples for testing. Floor division is
    # required: under Python 3 `num / 5` is a float, and a float slice bound
    # raises TypeError.
    num_train = num - num // 5
    train = slice(None, num_train)
    test = slice(num_train, None)

    train_split = (
        _stack_image_channels(im_data[train], value_data[train], np.float32),
        state_data[train],
        label_data[train],
    )
    test_split = (
        _stack_image_channels(im_data[test], value_data[test], np.float32),
        state_data[test],
        label_data[test],
    )
    return train_split, test_split, data['sample_idx']
utils.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录