gan.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:adagan 作者: tolstikhin 项目源码 文件源码
def generator(self, opts, noise, is_training, reuse=False):
        """Generator function, suitable for simple picture experiments.

        Maps latent noise through one linear layer followed by three
        deconvolution layers, each followed by batch normalization, and
        squashes the result with tanh or sigmoid.

        Args:
            opts: dict of options. This method reads 'g_num_filters' (number
                of filters of the first feature map) and
                'input_normalize_sym' (selects tanh instead of sigmoid for
                the output); the whole dict is also forwarded to the `ops`
                helpers.
            noise: [num_points, dim] array, where dim is dimensionality of the
                latent noise space.
            is_training: bool, defines whether to use batch_norm in the train
                or test mode.
            reuse: passed to tf.variable_scope and to ops.batch_norm.
        Returns:
            [num_points, dim1, dim2, dim3] array, where the first coordinate
            indexes the points, which all are of the shape (dim1, dim2, dim3).
        """

        output_shape = self._data.data_shape  # (dim1, dim2, dim3)
        # Computing the number of noise vectors on-the-go
        num_points = tf.shape(noise)[0]
        num_filters = opts['g_num_filters']

        with tf.variable_scope("GENERATOR", reuse=reuse):

            # Floor division keeps these sizes integral on Python 3 as well;
            # true division would hand floats to the linear layer, tf.reshape
            # and the deconvolution output shapes.
            # NOTE(review): the layer shapes below only line up with
            # output_shape when its height and width are multiples of 4
            # (e.g. 28 x 28) -- confirm for other datasets.
            height = output_shape[0] // 4
            width = output_shape[1] // 4
            h0 = ops.linear(opts, noise, num_filters * height * width,
                            scope='h0_lin')
            h0 = tf.reshape(h0, [-1, height, width, num_filters])
            h0 = ops.batch_norm(opts, h0, is_training, reuse, scope='bn_layer1')
            h0 = ops.lrelu(h0)
            _out_shape = [num_points, height * 2, width * 2, num_filters // 2]
            # for 28 x 28 does 7 x 7 --> 14 x 14
            h1 = ops.deconv2d(opts, h0, _out_shape, scope='h1_deconv')
            h1 = ops.batch_norm(opts, h1, is_training, reuse, scope='bn_layer2')
            h1 = ops.lrelu(h1)
            _out_shape = [num_points, height * 4, width * 4, num_filters // 4]
            # for 28 x 28 does 14 x 14 --> 28 x 28
            h2 = ops.deconv2d(opts, h1, _out_shape, scope='h2_deconv')
            h2 = ops.batch_norm(opts, h2, is_training, reuse, scope='bn_layer3')
            h2 = ops.lrelu(h2)
            _out_shape = [num_points] + list(output_shape)
            # data_shape[0] x data_shape[1] x ? -> data_shape
            # (d_h=1, d_w=1 are passed here, unlike the two layers above,
            # and the spatial size stays the same.)
            h3 = ops.deconv2d(opts, h2, _out_shape,
                              d_h=1, d_w=1, scope='h3_deconv')
            h3 = ops.batch_norm(opts, h3, is_training, reuse, scope='bn_layer4')

        if opts['input_normalize_sym']:
            return tf.nn.tanh(h3)
        else:
            return tf.nn.sigmoid(h3)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号