def crop_and_concat(x1, x2):
    """Center-crop ``x1`` to the extent of ``x2`` on axes 1-3, then concatenate.

    Axes 1-3 of ``x1`` are cropped symmetrically so their sizes match ``x2``.
    When the size difference on an axis is odd, the extra element is dropped
    from the high end. Axes 0 and 4 of ``x1`` are kept whole. The cropped
    tensor and ``x2`` are then concatenated along axis 4. This is the
    skip-connection merge used in 3D U-Net style networks.

    Args:
        x1: 5-D tensor whose sizes on axes 1-3 are statically known.
            Presumably an encoder feature map in
            (batch, depth, height, width, channels) layout -- confirm
            against callers.
        x2: 5-D tensor whose sizes on axes 1-3 are statically known and no
            larger than the corresponding sizes of ``x1``.

    Returns:
        The 5-D tensor ``tf.concat([cropped_x1, x2], 4)``.

    Raises:
        ValueError: if either input is not rank 5, has an unknown static size
            on axes 1-3, or ``x2`` is larger than ``x1`` on any of those axes.
    """
    x1_shape = x1.get_shape().as_list()
    x2_shape = x2.get_shape().as_list()
    if len(x1_shape) != 5 or len(x2_shape) != 5:
        raise ValueError(
            "crop_and_concat expects rank-5 tensors, got shapes {} and {}".format(
                x1_shape, x2_shape))

    x1_spatial = x1_shape[1:4]
    x2_spatial = x2_shape[1:4]
    # The crop window is computed from static shapes, so unknown (None) sizes
    # on the cropped axes cannot be handled here.
    if None in x1_spatial or None in x2_spatial:
        raise ValueError(
            "crop_and_concat needs statically known sizes on axes 1-3, "
            "got shapes {} and {}".format(x1_shape, x2_shape))
    # A negative offset would make tf.slice fail with an obscure error.
    if any(a < b for a, b in zip(x1_spatial, x2_spatial)):
        raise ValueError(
            "crop_and_concat needs x1 at least as large as x2 on axes 1-3, "
            "got shapes {} and {}".format(x1_shape, x2_shape))

    # Symmetric crop start. Any odd leftover falls off the high end.
    offsets = [0] + [(a - b) // 2 for a, b in zip(x1_spatial, x2_spatial)] + [0]
    # A size of -1 keeps every remaining element of axes 0 and 4.
    size = [-1] + x2_spatial + [-1]
    x1_crop = tf.slice(x1, offsets, size)
    return tf.concat([x1_crop, x2], 4)
# Some code from https://github.com/shiba24/3d-unet.git
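# (Leftover blog-page labels from the source this was copied from:
#  "评论列表" = "Comment list", "文章目录" = "Table of contents".)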