pot.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:adagan 作者: tolstikhin 项目源码 文件源码
def ali_deconv(self, opts, noise, is_training, reuse, keep_prob):
        """Build a transposed-convolution generator that maps noise to images.

        The noise is reshaped to a 1x1 spatial map and passed through five
        'VALID'-padded transposed convolutions, each optionally followed by
        batch norm and always by a leaky ReLU with slope 0.1.  With the
        (kernel, stride) pairs used below the spatial size grows
        1 -> 4 -> 10 -> 13 -> 28 -> 32, so the dataset held in `self._data`
        must have 32x32 images (this is asserted).  Two 1x1 convolutions then
        reduce the channel count to that of the data.

        Args:
            opts: Options dictionary.  The keys read here are
                'g_num_filters' (base number of filters), 'batch_norm' and
                'input_normalize_sym'; `opts` is also forwarded to the
                helpers in `ops`.
            noise: 2-D tensor whose second dimension is statically known;
                the first dimension is treated as the batch.
            is_training: Forwarded to `ops.batch_norm`.
            reuse: Forwarded to `ops.batch_norm`.
            keep_prob: Unused by this architecture; kept so the signature
                stays as callers expect.

        Returns:
            `tf.nn.tanh` of the last layer when opts['input_normalize_sym']
            is truthy, otherwise `tf.nn.sigmoid` of it.
            NOTE(review): presumably shaped
            (batch, data_height, data_width, data_channels) -- this depends on
            the padding `ops.conv2d` uses for 1x1 kernels; confirm.

        Raises:
            AssertionError: If the spatial size produced by the transposed
                convolutions differs from the data height or width.
        """
        output_shape = self._data.data_shape

        batch_size = tf.shape(noise)[0]
        noise_size = int(noise.get_shape()[1])
        data_height = output_shape[0]
        data_width = output_shape[1]
        data_channels = output_shape[2]

        noise = tf.reshape(noise, [-1, 1, 1, noise_size])

        num_units = opts['g_num_filters']
        # Each entry is [kernel size, stride, output channels].
        # Floor division keeps the channel counts integral: with `/` they
        # become floats under Python 3 and are then used as shape entries
        # and filter counts, where integers are required.
        layer_params = [
            [4, 1, num_units],
            [4, 2, num_units // 2],
            [4, 1, num_units // 4],
            [4, 2, num_units // 8],
            [5, 1, num_units // 8],
        ]
        # For convolution: (n - k) / stride + 1 = s
        # For transposed: (s - 1) * stride + k = n
        layer_x = noise
        height = 1
        width = 1
        for i, (kernel, stride, channels) in enumerate(layer_params):
            height = (height - 1) * stride + kernel
            # Feature maps are square, so the width simply follows the height.
            width = height
            layer_x = ops.deconv2d(
                opts, layer_x, [batch_size, height, width, channels],
                d_h=stride, d_w=stride, scope='h%d_deconv' % i,
                conv_filters_dim=kernel, padding='VALID')
            if opts['batch_norm']:
                layer_x = ops.batch_norm(
                    opts, layer_x, is_training, reuse, scope='bn%d' % i)
            layer_x = ops.lrelu(layer_x, 0.1)
        # The layer table is hard-coded, so it only fits one data resolution.
        assert height == data_height, (
            'Generator produces height %d, data has %d' % (height, data_height))
        assert width == data_width, (
            'Generator produces width %d, data has %d' % (width, data_width))

        # Then two 1x1 convolutions.
        layer_x = ops.conv2d(
            opts, layer_x, num_units // 8, d_h=1, d_w=1,
            scope='conv2d_1x1', conv_filters_dim=1)
        if opts['batch_norm']:
            layer_x = ops.batch_norm(
                opts, layer_x, is_training, reuse, scope='bnlast')
        layer_x = ops.lrelu(layer_x, 0.1)
        layer_x = ops.conv2d(
            opts, layer_x, data_channels, d_h=1, d_w=1,
            scope='conv2d_1x1_2', conv_filters_dim=1)

        # NOTE(review): presumably tanh matches data normalized to [-1, 1]
        # and sigmoid data in [0, 1] -- confirm against the data pipeline.
        if opts['input_normalize_sym']:
            return tf.nn.tanh(layer_x)
        else:
            return tf.nn.sigmoid(layer_x)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号