pot.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:adagan 作者: tolstikhin 项目源码 文件源码
def dcgan_like_arch(self, opts, noise, is_training, reuse, keep_prob):
        """Build a DCGAN-style generator that maps ``noise`` to images.

        The noise is projected by a linear layer and reshaped to a small
        ``height x width x g_num_filters`` tensor. It then passes through
        ``g_num_layers - 1`` deconvolutions, each doubling the spatial size
        and halving the channel count. A final deconvolution maps the result
        to ``self._data.data_shape``.

        Args:
            opts: Options dict. Keys read here: 'g_num_filters',
                'g_num_layers', 'g_arch' ('dcgan' or 'dcgan_mod'),
                'g_stride1_deconv', 'batch_norm', 'dropout',
                'input_normalize_sym'.
            noise: Input noise tensor; its first dimension is used as the
                batch size.
            is_training: Forwarded to ``ops.batch_norm``.
            reuse: Forwarded to ``ops.batch_norm``.
            keep_prob: Dropout keep probability reached at the last hidden
                layer (only used when opts['dropout'] is set).

        Returns:
            The generated images: ``tf.nn.tanh`` of the last layer if
            opts['input_normalize_sym'] is set, else ``tf.nn.sigmoid``.

        Raises:
            ValueError: if opts['g_arch'] is neither 'dcgan' nor 'dcgan_mod'.
        """
        output_shape = self._data.data_shape
        num_units = opts['g_num_filters']

        batch_size = tf.shape(noise)[0]
        num_layers = opts['g_num_layers']
        arch = opts['g_arch']
        if arch == 'dcgan':
            # Every deconvolution, including the last one, doubles H and W.
            num_upsamplings = num_layers
        elif arch == 'dcgan_mod':
            # The last deconvolution uses stride 1, so one fewer doubling.
            num_upsamplings = num_layers - 1
        else:
            # Raise instead of `assert False`, which disappears under -O.
            raise ValueError('Unknown g_arch: %r' % (arch,))

        # Floor division: shapes must be ints (plain `/` yields floats on
        # Python 3, which tf.reshape and the deconv output shapes reject).
        height = output_shape[0] // 2**num_upsamplings
        width = output_shape[1] // 2**num_upsamplings

        h0 = ops.linear(
            opts, noise, num_units * height * width, scope='h0_lin')
        h0 = tf.reshape(h0, [-1, height, width, num_units])
        h0 = tf.nn.relu(h0)
        layer_x = h0
        for i in range(num_layers - 1):
            scale = 2**(i + 1)
            if opts['g_stride1_deconv']:
                # NOTE(review): original author flagged this extra 1x1-stride
                # deconvolution as questionable ("Sylvain, I'm worried about
                # this part!"). It keeps the current spatial size and the
                # current channel count before the upsampling deconv below.
                _out_shape = [batch_size, height * scale // 2,
                              width * scale // 2, num_units // scale * 2]
                layer_x = ops.deconv2d(
                    opts, layer_x, _out_shape, d_h=1, d_w=1,
                    scope='h%d_deconv_1x1' % i)
                layer_x = tf.nn.relu(layer_x)
            _out_shape = [batch_size, height * scale, width * scale,
                          num_units // scale]
            layer_x = ops.deconv2d(
                opts, layer_x, _out_shape, scope='h%d_deconv' % i)
            if opts['batch_norm']:
                layer_x = ops.batch_norm(
                    opts, layer_x, is_training, reuse, scope='bn%d' % i)
            layer_x = tf.nn.relu(layer_x)
            if opts['dropout']:
                # Keep probability moves linearly from 0.9 towards keep_prob,
                # reaching keep_prob at the last hidden layer (capped at 1).
                _keep_prob = tf.minimum(
                    1., 0.9 - (0.9 - keep_prob) * float(i + 1) / (num_layers - 1))
                layer_x = tf.nn.dropout(layer_x, _keep_prob)

        _out_shape = [batch_size] + list(output_shape)
        if arch == 'dcgan':
            last_h = ops.deconv2d(
                opts, layer_x, _out_shape, scope='hlast_deconv')
        else:
            # 'dcgan_mod' (validated above): final layer keeps spatial size.
            last_h = ops.deconv2d(
                opts, layer_x, _out_shape, d_h=1, d_w=1, scope='hlast_deconv')

        if opts['input_normalize_sym']:
            return tf.nn.tanh(last_h)
        else:
            return tf.nn.sigmoid(last_h)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号