def valid_generator():
    """Yield validation batches of ``(images, masks)`` forever.

    Walks the module-level ``ids_valid_split`` in consecutive slices of at
    most ``batch_size`` ids and restarts from the beginning once the end is
    reached. The last slice of a pass may hold fewer than ``batch_size`` ids.

    For every id the image ``input/train/<id>.jpg`` and the grayscale mask
    ``input/train_masks/<id>_mask.png`` are read with OpenCV, resized to
    ``input_size`` x ``input_size``, and the mask gets a trailing axis
    added (``axis=2``).

    Relies on names defined outside this function: ``ids_valid_split``
    (must support ``len``, slicing and ``.values`` -- presumably a pandas
    Series, confirm against the caller), ``batch_size``, ``input_size``,
    ``cv2`` and ``np``.

    Yields:
        tuple: ``(x_batch, y_batch)``, both ``np.float32`` arrays divided
        by 255.

    Raises:
        ValueError: if ``ids_valid_split`` is empty.
        OSError: if an image or mask file cannot be read.
    """
    while True:
        total = len(ids_valid_split)
        if total == 0:
            # Without this guard the batch loop below never runs, so the
            # generator would spin forever without yielding anything.
            raise ValueError(
                'ids_valid_split is empty; no validation batches to yield')

        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            ids_valid_batch = ids_valid_split[start:end]

            x_batch = []
            y_batch = []
            for image_id in ids_valid_batch.values:
                img_path = 'input/train/{}.jpg'.format(image_id)
                img = cv2.imread(img_path)
                if img is None:
                    # cv2.imread signals failure by returning None rather
                    # than raising; fail here so the path is reported.
                    raise OSError(
                        'could not read image: {}'.format(img_path))
                img = cv2.resize(img, (input_size, input_size))

                mask_path = 'input/train_masks/{}_mask.png'.format(image_id)
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    raise OSError(
                        'could not read mask: {}'.format(mask_path))
                mask = cv2.resize(mask, (input_size, input_size))
                # Add a trailing axis so the mask is 3-D like the image.
                mask = np.expand_dims(mask, axis=2)

                x_batch.append(img)
                y_batch.append(mask)

            # Convert to float32 and scale by 1/255.
            x_batch = np.array(x_batch, np.float32) / 255
            y_batch = np.array(y_batch, np.float32) / 255
            yield x_batch, y_batch
train.py 文件源码
python
阅读 36
收藏 0
点赞 0
评论 0
评论列表
文章目录