def _bilinear_interpolate(self, im, im_org, x, y):
    """Scatter the pixels of ``im`` to positions (x, y) with bilinear weights.

    Each row of the flattened ``im`` is written to the four integer grid
    positions surrounding its (x, y) coordinate via ``tf.scatter_nd``; the
    scattered values are then divided by the scattered weights. Positions
    whose total weight does not exceed 1e-5 take their value from
    ``im_org`` instead.

    Args:
        im: Image tensor; axis 1 is used as the y extent and axis 2 as the
            x extent. Assumed to be (batch, height, width, channels) --
            TODO confirm against caller.
        im_org: Tensor reshaped to ``[-1, self.num_channels]`` and used as
            the fallback value where nothing was scattered.
        x: Horizontal coordinates, nominally in [-1, 1]. Assumed to be a
            flat vector with one entry per row of the flattened ``im`` --
            TODO confirm.
        y: Vertical coordinates, same conventions as ``x``.

    Returns:
        A float32 tensor with
        ``self.num_batch * self.out_height * self.out_width`` rows and
        ``self.num_channels`` columns.
    """
    # NOTE(review): flat indices are built with self.height/self.width
    # strides while the scatter output is sized with
    # self.out_height/self.out_width; this looks like it assumes the two
    # sizes agree -- confirm.
    with tf.variable_scope('_interpolate'):
        # Coordinates are handled as float32, pixel indices as int32.
        x = tf.cast(x, 'float32')
        y = tf.cast(y, 'float32')
        height_f = tf.cast(self.height, 'float32')
        width_f = tf.cast(self.width, 'float32')
        lower = tf.zeros([], dtype='int32')
        upper_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
        upper_x = tf.cast(tf.shape(im)[2] - 1, 'int32')

        # Map coordinates from [-1, 1] to [0, width] / [0, height].
        x_px = (x + 1.0) * width_f / 2.0
        y_px = (y + 1.0) * height_f / 2.0

        # Integer neighbours on each axis; the +1 neighbour is derived
        # before clipping, then both are clamped into the image.
        x_lo = tf.cast(tf.floor(x_px), 'int32')
        x_hi = x_lo + 1
        y_lo = tf.cast(tf.floor(y_px), 'int32')
        y_hi = y_lo + 1
        x_lo = tf.clip_by_value(x_lo, lower, upper_x)
        x_hi = tf.clip_by_value(x_hi, lower, upper_x)
        y_lo = tf.clip_by_value(y_lo, lower, upper_y)
        y_hi = tf.clip_by_value(y_hi, lower, upper_y)

        # Flat index = batch offset + row * width + column.
        row_stride = self.width
        image_stride = self.width * self.height
        num_out = self.num_batch * self.out_height * self.out_width
        batch_offset = self._repeat(
            tf.range(self.num_batch) * image_stride,
            self.out_height * self.out_width,
            'int32',
        )
        rows_lo = batch_offset + y_lo * row_stride
        rows_hi = batch_offset + y_hi * row_stride

        # Corner order (a, b, c, d) is shared by indices and weights.
        corner_indices = (
            tf.expand_dims(rows_lo + x_lo, 1),
            tf.expand_dims(rows_hi + x_lo, 1),
            tf.expand_dims(rows_lo + x_hi, 1),
            tf.expand_dims(rows_hi + x_hi, 1),
        )

        # Flatten the image to one row per pixel, keeping the channel axis.
        pixels = tf.reshape(im, tf.stack([-1, self.num_channels]))
        pixels = tf.cast(pixels, 'float32')
        value_shape = [num_out, self.num_channels]
        scattered_values = [
            tf.scatter_nd(idx, pixels, value_shape) for idx in corner_indices
        ]

        x_lo_f = tf.cast(x_lo, 'float32')
        x_hi_f = tf.cast(x_hi, 'float32')
        y_lo_f = tf.cast(y_lo, 'float32')
        y_hi_f = tf.cast(y_hi, 'float32')
        corner_weights = (
            (x_hi_f - x_px) * (y_hi_f - y_px),
            (x_hi_f - x_px) * (y_px - y_lo_f),
            (x_px - x_lo_f) * (y_hi_f - y_px),
            (x_px - x_lo_f) * (y_px - y_lo_f),
        )
        weight_shape = [num_out, 1]
        scattered_weights = [
            tf.scatter_nd(idx, tf.expand_dims(weight, 1), weight_shape)
            for idx, weight in zip(corner_indices, corner_weights)
        ]

        value_all = tf.add_n(
            [w * v for w, v in zip(scattered_weights, scattered_values)]
        )
        # Clamp the accumulated weight so the division below is safe.
        weight_all = tf.clip_by_value(
            tf.add_n(scattered_weights), 1e-5, 1e+10
        )

        # 1.0 where the (clamped) weight sits at the floor, i.e. the
        # position is filled from im_org; 0.0 elsewhere.
        unfilled = tf.less_equal(weight_all, 1e-5 * tf.ones_like(weight_all))
        unfilled = tf.cast(unfilled, tf.float32)
        fallback = tf.reshape(im_org, [-1, self.num_channels])
        return tf.add(
            tf.div(value_all, weight_all),
            tf.multiply(fallback, unfilled),
        )
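# Source file: Dense_Transformer_Network.py (python)
# Web-page residue from the snippet site: views 35, favorites 0, likes 0,
# comments 0; comment list; table of contents.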