train_utils.py source code

python

Project: cancer  Author: yancz1989
import tensorflow as tf


def interp(w, i, channel_dim):
  '''
  Input:
    w: A 4D block tensor of shape (n, h, w, c)
    i: A float tensor of shape (k, 3) whose rows are the points
      (x_1, y_1, z_1), (x_2, y_2, z_2), ..., (x_k, y_k, z_k); each x is an
      integral batch index, while y and z may be fractional
    channel_dim: the number of channels, c

    The 4D block represents a batch of n feature volumes of shape (h, w, c).
    Each row of i is a point to index into w via interpolation. Direct
    indexing is not possible because y and z are float values, so each
    result is bilinearly interpolated from the four grid points around (y, z).
  Output:
    A tensor of shape (k, c) whose rows are the values:
      w[x_1, y_1, z_1, :],
      w[x_2, y_2, z_2, :],
      ...
      w[x_k, y_k, z_k, :]
    so the output has len(i) rows.
  '''
  w_as_vector = tf.reshape(w, [-1, channel_dim])  # one row of c channel values per (n, h, w) position, so tf.gather can select whole rows
  # Integer (x, y, z) corners of the grid cell around each point:
  # floor/ceil of y picks the upper/lower row, floor/ceil of z the left/right column.
  upper_l = tf.cast(tf.concat(axis=1, values=[i[:, 0:1], tf.floor(i[:, 1:2]), tf.floor(i[:, 2:3])]), tf.int32)
  upper_r = tf.cast(tf.concat(axis=1, values=[i[:, 0:1], tf.floor(i[:, 1:2]), tf.math.ceil(i[:, 2:3])]), tf.int32)
  lower_l = tf.cast(tf.concat(axis=1, values=[i[:, 0:1], tf.math.ceil(i[:, 1:2]), tf.floor(i[:, 2:3])]), tf.int32)
  lower_r = tf.cast(tf.concat(axis=1, values=[i[:, 0:1], tf.math.ceil(i[:, 1:2]), tf.math.ceil(i[:, 2:3])]), tf.int32)

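  # to_idx is defined elsewhere in train_utils.py (not shown); it is assumed to
  # map each (x, y, z) corner to its row in w_as_vector.
  upper_l_idx = to_idx(upper_l, tf.shape(w))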
  upper_r_idx = to_idx(upper_r, tf.shape(w))
  lower_l_idx = to_idx(lower_l, tf.shape(w))
  lower_r_idx = to_idx(lower_r, tf.shape(w))

  # Each gather returns a (k, c) tensor holding one corner's channel values.
  upper_l_value = tf.gather(w_as_vector, upper_l_idx)
  upper_r_value = tf.gather(w_as_vector, upper_r_idx)
  lower_l_value = tf.gather(w_as_vector, lower_l_idx)
  lower_r_value = tf.gather(w_as_vector, lower_r_idx)

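  # Fractional offsets along z (left/right) and y (up/down), shaped (k, 1)
  # so they broadcast over the c channels.
  alpha_lr = tf.expand_dims(i[:, 2] - tf.floor(i[:, 2]), 1)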
  alpha_ud = tf.expand_dims(i[:, 1] - tf.floor(i[:, 1]), 1)

  # Interpolate along z within the upper and lower rows, then along y between them.
  upper_value = (1 - alpha_lr) * upper_l_value + (alpha_lr) * upper_r_value
  lower_value = (1 - alpha_lr) * lower_l_value + (alpha_lr) * lower_r_value
  value = (1 - alpha_ud) * upper_value + (alpha_ud) * lower_value
  return value
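To make the indexing and weighting in interp concrete, here is a minimal pure-Python sketch of the same bilinear lookup. It was written for illustration and is not code from the cancer project. It assumes w is a nested list indexed as w[n][y][z][channel] and that each point is an (x, y, z) tuple. The helper flat_row is hypothetical: it stands in for to_idx, whose definition is not shown, under the assumption that to_idx uses row-major order over (n, h, w).

```python
from math import ceil, floor
from typing import List, Sequence, Tuple

# Nested-list stand-in for an (n, h, w, c) tensor: w[n][y][z][channel].
Volume = List[List[List[List[float]]]]


def flat_row(n: int, y: int, z: int, h: int, width: int) -> int:
    """Row of (n, y, z) after reshaping w to (n*h*width, c), row-major.

    Hypothetical stand-in for to_idx; its real definition is not shown."""
    return (n * h + y) * width + z


def interp_reference(w: Volume,
                     points: Sequence[Tuple[int, float, float]]) -> List[List[float]]:
    """Bilinearly sample w at fractional (y, z) for each (x, y, z) point."""
    h, width = len(w[0]), len(w[0][0])
    # Flatten to one row of c channel values per (n, y, z) position,
    # mirroring tf.reshape(w, [-1, channel_dim]).
    rows = [pixel for image in w for line in image for pixel in line]
    result = []
    for x, y, z in points:
        n = int(x)
        y0, y1 = floor(y), ceil(y)
        z0, z1 = floor(z), ceil(z)
        # The four corners of the grid cell containing (y, z).
        upper_l = rows[flat_row(n, y0, z0, h, width)]
        upper_r = rows[flat_row(n, y0, z1, h, width)]
        lower_l = rows[flat_row(n, y1, z0, h, width)]
        lower_r = rows[flat_row(n, y1, z1, h, width)]
        alpha_lr = z - z0  # fractional offset along z (left/right)
        alpha_ud = y - y0  # fractional offset along y (up/down)
        # Interpolate along z first, then along y, channel by channel.
        upper = [(1 - alpha_lr) * a + alpha_lr * b for a, b in zip(upper_l, upper_r)]
        lower = [(1 - alpha_lr) * a + alpha_lr * b for a, b in zip(lower_l, lower_r)]
        result.append([(1 - alpha_ud) * u + alpha_ud * d for u, d in zip(upper, lower)])
    return result


if __name__ == "__main__":
    # One 2 x 2 single-channel map holding the ramp w[0][y][z][0] = 2*y + z.
    ramp = [[[[0.0], [1.0]], [[2.0], [3.0]]]]
    print(interp_reference(ramp, [(0, 0.5, 0.5), (0, 1.0, 0.25)]))  # [[1.5], [2.25]]
```

Bilinear interpolation reproduces a linear function exactly, so sampling the ramp at (y, z) = (0.5, 0.5) and (1.0, 0.25) returns 2*y + z at each point. When y or z is already an integer, floor and ceil coincide and the corresponding weight is zero, which is the same behavior the TensorFlow version relies on.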