def setup_transform_net(self, input_var=None):
    """Build the transform network and store it as ``self.network['transform_net']``.

    Layer stack, in order:

    1. ``InputLayer`` with ``shape=self.shape``, bound to ``input_var``.
    2. Three ``style_conv_block`` layers called with the positional
       arguments ``(32, 9, 1)``, ``(64, 3, 2)`` and ``(128, 3, 2)``
       (presumably filters, filter size, stride -- confirm against
       ``style_conv_block``).
    3. Five ``residual_block`` layers.
    4. Two ``nn_upsample`` layers.
    5. A final ``style_conv_block`` with arguments ``(3, 9, 1)`` whose
       nonlinearity is selected by ``self.net_type``:

       * ``0``: ``tanh``, followed by an ``ExpressionLayer`` that
         multiplies the output by 150.
       * ``1``: ``sigmoid``, with no rescaling.

    Every style-conditioned layer is passed ``self.num_styles``.

    Args:
        input_var: Optional symbolic input variable, forwarded unchanged
            to ``InputLayer``.

    Raises:
        ValueError: If ``self.net_type`` is neither 0 nor 1. Nothing is
            stored in ``self.network`` in that case.
    """
    net_type = self.net_type
    # Fail fast on an unsupported net_type. The branch below selects the
    # output block, and previously any other value skipped it silently,
    # storing a network that ended at the last upsampling layer.
    if net_type not in (0, 1):
        raise ValueError(
            "net_type must be 0 (tanh output) or 1 (sigmoid output), "
            "got %r" % (net_type,))

    num_styles = self.num_styles
    transform_net = InputLayer(shape=self.shape, input_var=input_var)

    transform_net = style_conv_block(transform_net, num_styles, 32, 9, 1)
    transform_net = style_conv_block(transform_net, num_styles, 64, 3, 2)
    transform_net = style_conv_block(transform_net, num_styles, 128, 3, 2)

    for _ in range(5):
        transform_net = residual_block(transform_net, num_styles)

    transform_net = nn_upsample(transform_net, num_styles)
    transform_net = nn_upsample(transform_net, num_styles)

    if net_type == 0:
        transform_net = style_conv_block(transform_net, num_styles, 3, 9, 1, tanh)
        # Scale the tanh output by 150 (presumably to reach the expected
        # pixel-value range -- confirm). output_shape=None is passed
        # through unchanged from the original code.
        transform_net = ExpressionLayer(transform_net, lambda X: 150. * X,
                                        output_shape=None)
    else:  # net_type == 1, guaranteed by the validation above
        transform_net = style_conv_block(transform_net, num_styles, 3, 9, 1, sigmoid)

    # NOTE(review): assumes self.network already exists as a dict-like
    # container; it is not created here.
    self.network['transform_net'] = transform_net
评论列表
文章目录