def preprocess_images(images):
    """Build a four-channel network input from a batch of frames.

    Each frame that is used is resized to 80x80 with ``imresize``, cast to
    float and divided by 255.0. Four such frames are stacked along a new
    axis 2, and a leading axis of length 1 is then added.

    Args:
        images: Indexable collection of frames exposing ``shape``;
            ``shape[0]`` is read as the number of frames. (Presumably a
            NumPy array of 8-bit images -- confirm against the caller.)

    Returns:
        The stacked array with an extra leading axis. Assuming ``imresize``
        returns a 2-D 80x80 array per frame, the shape is (1, 80, 80, 4).
    """
    if images.shape[0] < 4:
        # Fewer than four frames: repeat the first frame in all four slots.
        frame = _resize_and_scale(images[0])
        channels = [frame] * 4
    else:
        # Only the first four frames end up in the stack, so only those
        # are resized; any additional frames are ignored.
        channels = [_resize_and_scale(images[i]) for i in range(4)]
    s_t = np.stack(channels, axis=2)
    return np.expand_dims(s_t, axis=0)


def _resize_and_scale(frame):
    """Resize one frame to 80x80 and return it as floats divided by 255."""
    resized = imresize(frame, (80, 80))
    return resized.astype("float") / 255.0
rl-network-train.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录