def create_generator(hr_image_bilinear, num_channels, cfg):
    """Build a U-Net-style encoder-decoder generator (pix2pix-like).

    Args:
        hr_image_bilinear: input tensor, NHWC — presumably a bilinearly
            upsampled low-resolution image; TODO confirm with the caller.
        num_channels: number of channels of the generated output image.
        cfg: config object; only ``cfg.ngf`` (base filter count) is read.

    Returns:
        Output tensor produced by a final stride-2 transposed convolution
        with tanh activation (values in [-1, 1]).
    """
    layers = []
    print(hr_image_bilinear.get_shape())
    # First encoder layer: stride-2 conv at the base filter width.
    conv = slim.conv2d(hr_image_bilinear, cfg.ngf, [3,3], stride = 2, scope = 'encoder0')
    layers.append(conv)

    # Remaining encoder widths; each layer halves the spatial resolution.
    encoder_specs = [
        cfg.ngf*2,
        cfg.ngf*4,
        cfg.ngf*8,
        cfg.ngf*8,
        cfg.ngf*8,
        cfg.ngf*8,
    ]
    for idx, out_channels in enumerate(encoder_specs):
        with slim.arg_scope([slim.conv2d], activation_fn = lrelu, stride = 2, padding = 'VALID'):
            # NOTE(review): this calls the bare name `conv2d`, not
            # `slim.conv2d`. If `conv2d` is not an alias of `slim.conv2d`,
            # the arg_scope above has no effect on this call — verify
            # against the definition/import of `conv2d` in this file.
            conv = conv2d(layers[-1], out_channels, scope = 'encoder%d'%(idx+1))
            print(conv.get_shape())
            layers.append(conv)

    ### decoder part
    # (output width, dropout rate); 0.5 dropout on the two deepest decoder
    # layers, as in pix2pix.
    decoder_specs = [
        (cfg.ngf*8, 0.5),
        (cfg.ngf*8, 0.5),
        (cfg.ngf*8, 0.0),
        (cfg.ngf*4, 0.0),
        (cfg.ngf*2, 0.0),
        (cfg.ngf, 0.0)
    ]
    num_encoder_layers = len(layers)
    for decoder_layer_idx, (out_channels, dropout) in enumerate(decoder_specs):
        # U-Net skip connection: mirror index into the encoder stack.
        skip_layer = num_encoder_layers - decoder_layer_idx - 1
        with slim.arg_scope([slim.conv2d], activation_fn = lrelu):
            if decoder_layer_idx == 0:
                # Deepest layer: no skip connection, feed the bottleneck.
                net = layers[-1]
            else:
                net = tf.concat([layers[-1], layers[skip_layer]], axis = 3)
            output = upsample_layer(net, out_channels, mode = 'deconv')
            print(output.get_shape())
            if dropout > 0.0:
                # TF1 dropout API: keep_prob is the complement of the rate.
                output = tf.nn.dropout(output, keep_prob = 1 - dropout)
            layers.append(output)

    # Final 2x upsample with a skip from the first encoder layer; tanh
    # squashes the generated image into [-1, 1].
    net = tf.concat([layers[-1], layers[0]], axis = 3)
    output = slim.conv2d_transpose(net, num_channels, [4,4], stride = 2, activation_fn = tf.tanh)
    return output