def up_block(block_fn, filters):
    """Create a step that upsamples a coarse tensor and fuses it with a skip tensor.

    Args:
        block_fn: Factory invoked as ``block_fn(filters)``. The callable it
            returns is applied to a tensor. The factory is called twice per
            invocation of the returned function, so two distinct blocks are
            stacked.
        filters: Forwarded unchanged to ``block_fn``.

    Returns:
        A function ``merge(inputs, down)`` that:
          1. upsamples ``down`` with a 2x2, stride-2, SAME-padded transposed
             convolution, using as many output channels as ``inputs`` has on
             axis 3;
          2. concatenates ``inputs`` and the upsampled tensor along axis 3;
          3. passes the result through two freshly built ``block_fn(filters)``
             blocks and returns it.
    """
    def merge(inputs, down):
        # Static size of axis 3 of the skip tensor. Indexing [3] together with
        # concat on axis 3 presumes 4-D NHWC tensors with a known channel
        # count -- TODO confirm with callers.
        skip_channels = inputs.shape.as_list()[3]

        # Kernel 2, stride 2, SAME padding: per the original note, this doubles
        # the spatial size of 'down' so it can line up with 'inputs'.
        upsampled = tcl.conv2d_transpose(
            down,
            num_outputs=skip_channels,
            kernel_size=(2, 2),
            stride=(2, 2),
            padding='SAME',
        )

        # Skip tensor first, upsampled tensor second (order matters for the
        # channel layout seen by the following blocks).
        merged = tf.concat([inputs, upsampled], axis=3)

        # Two independent blocks, each built by a separate block_fn call.
        for _ in range(2):
            merged = block_fn(filters)(merged)

        # Original note: result has the same size as 'inputs' (presumably
        # spatial size, assuming block_fn preserves it -- verify).
        return merged

    return merge
评论列表
文章目录