def padding3D(input, width_mode, pad_factor):
    """Zero-pad (and, in 'match' mode, crop) the last three axes of an array.

    Axes before the last three are never padded or cropped.

    Args:
        input: NumPy array whose last three axes are processed.
        width_mode: One of
            - ``'multiple'``: append zeros to the end of each of the last three
              axes so its length becomes a multiple of ``pad_factor``.
            - ``'fixed'``: pass ``pad_factor`` straight to ``np.pad`` as its
              ``pad_width`` argument.
            - ``'match'``: append zeros to, or crop the end of, each of the
              last three axes so their lengths equal ``pad_factor``.
        pad_factor: A positive integer for ``'multiple'``; an ``np.pad``
            ``pad_width`` list/tuple for ``'fixed'``; a list/tuple of target
            lengths for the last three axes for ``'match'``.

    Returns:
        The padded (and, for ``'match'``, possibly cropped) array.

    Raises:
        AssertionError: If ``pad_factor`` has the wrong type for the chosen
            mode, or is not positive in ``'multiple'`` mode.
        ValueError: If ``width_mode`` is not a supported mode.
    """
    if width_mode == 'multiple':
        assert isinstance(pad_factor, (int, np.integer))
        assert pad_factor > 0, "pad_factor must be a positive integer"
        shape = input.shape[-3:]
        added_shape = [(0, 0)] * len(input.shape[:-3])
        for dim in shape:
            # (-dim) % k is the distance up to the next multiple of k
            # (0 when dim is already a multiple); dim % k would not align.
            added_shape.append((0, (-dim) % pad_factor))
        output = np.pad(input, tuple(added_shape), 'constant', constant_values=(0, 0))
    elif width_mode == 'fixed':
        assert isinstance(pad_factor, (list, tuple))
        output = np.pad(input, tuple(pad_factor), 'constant', constant_values=(0, 0))
    elif width_mode == 'match':
        assert isinstance(pad_factor, (list, tuple))
        shape = input.shape[-3:]
        shape_difference = np.asarray(pad_factor) - np.asarray(shape)
        added_shape = [(0, 0)] * len(input.shape[:-3])
        subs_shape = [np.s_[:]] * len(input.shape[:-3])
        for diff in shape_difference:
            if diff < 0:
                # Axis is longer than the target: crop its tail, no padding.
                subs_shape.append(np.s_[:diff])
                added_shape.append((0, 0))
            else:
                # Axis is shorter than (or equal to) the target: pad its tail.
                subs_shape.append(np.s_[:])
                added_shape.append((0, diff))
        output = np.pad(input, tuple(added_shape), 'constant', constant_values=(0, 0))
        # A multidimensional index must be a tuple; NumPy rejects a list of slices.
        output = output[tuple(subs_shape)]
    else:
        raise ValueError("Padding3D error (src.helpers.preprocessing_utils): No existen padding method " + str(width_mode))
    return output
评论列表
文章目录