def interp(w, i, channel_dim):
    """Bilinearly sample a batch of feature maps at fractional (y, z) points.

    Input:
        w: A 4D tensor of shape (n, h, w, c): a batch of n feature maps of
            size h x w with c channels.
        i: A 2D float tensor of shape (k, 3). Each row is a point
            (x, y, z): x is a batch index (integral value stored as float;
            it is truncated by the int cast), y and z are fractional
            coordinates along the h and w axes respectively.
        channel_dim: Number of channels; ``w`` is reshaped to rows of this
            length, so it should equal c.

    Direct indexing is not possible because y and z are floats, so each
    point is the bilinear blend of its four integer neighbours
    (floor/ceil of y) x (floor/ceil of z) within batch element x. When y or
    z is already integral, floor == ceil and the blend weight on the
    duplicate corner is 0.

    No clamping is performed: ceil(y) and ceil(z) must be valid indices
    into the h and w axes.

    Output:
        A tensor of shape (k, channel_dim) whose j-th row is the
        interpolated value of w[x_j, y_j, z_j, :].
    """
    # Flatten every axis except channels so tf.gather (which indexes along
    # axis 0) returns one channel vector per flat index.
    w_as_vector = tf.reshape(w, [-1, channel_dim])
    w_shape = tf.shape(w)

    # Keep (k, 1) column slices so the concat below yields (k, 3) coords and
    # the blend weights broadcast across the channel axis.
    batch = i[:, 0:1]
    y = i[:, 1:2]
    z = i[:, 2:3]
    y_lo, y_hi = tf.floor(y), tf.ceil(y)
    z_lo, z_hi = tf.floor(z), tf.ceil(z)

    def corner_value(y_idx, z_idx):
        # Integer (batch, y, z) coordinates -> flat row index -> channel row.
        # NOTE(review): assumes to_idx (defined elsewhere) maps coordinates
        # and the 4-D shape to row indices of the (n*h*w, c) reshape.
        coords = tf.cast(tf_concat(1, [batch, y_idx, z_idx]), tf.int32)
        return tf.gather(w_as_vector, to_idx(coords, w_shape))

    upper_l_value = corner_value(y_lo, z_lo)
    upper_r_value = corner_value(y_lo, z_hi)
    lower_l_value = corner_value(y_hi, z_lo)
    lower_r_value = corner_value(y_hi, z_hi)

    # Fractional offsets within the cell: horizontal (z) and vertical (y).
    alpha_lr = z - z_lo
    alpha_ud = y - y_lo
    upper_value = (1 - alpha_lr) * upper_l_value + alpha_lr * upper_r_value
    lower_value = (1 - alpha_lr) * lower_l_value + alpha_lr * lower_r_value
    return (1 - alpha_ud) * upper_value + alpha_ud * lower_value
评论列表
文章目录