pot.py 文件源码

python
阅读 64 收藏 0 点赞 0 评论 0

项目: adagan 作者: tolstikhin 项目源码 文件源码
def began_dec(self, opts, noise, is_training, reuse, keep_prob):
        """Decoder (generator) network following the BEGAN architecture.

        Architecture reported here: https://arxiv.org/pdf/1703.10717.pdf

        The noise is projected by a linear layer and reshaped into an 8x8
        feature map ``h0`` with ``opts['g_num_filters']`` channels. Then
        ``opts['g_num_layers']`` layers are applied, cycling with period 3:

          * ``i % 3 in (0, 1)``: stride-1 conv with ``g_num_filters``
            filters followed by ELU (resolution unchanged);
          * ``i % 3 == 2``: 2x nearest-neighbour upsampling of the running
            features, plus a skip connection that upsamples ``h0`` to the
            same resolution and concatenates it on axis 3. If this index is
            the very last layer, nothing is done for it.

        A final stride-1 conv maps to ``output_shape[-1]`` channels, where
        ``output_shape = self._data.data_shape``.

        NOTE(review): the output spatial size is 8 * 2**k, with k the number
        of upsampling steps, so it is determined by ``g_num_layers`` alone.
        It is not checked against ``output_shape`` — confirm the
        configuration keeps them consistent.

        Args:
            opts: options dict. Reads 'g_num_filters', 'g_num_layers' and
                'input_normalize_sym'.
            noise: latent input tensor. Presumably (batch, noise_dim) —
                TODO confirm against ops.linear.
            is_training: not used by this architecture.
            reuse: forwarded to ops.upsample_nn.
            keep_prob: not used by this architecture.

        Returns:
            tanh of the final conv output if opts['input_normalize_sym'] is
            truthy, otherwise its sigmoid.
        """

        output_shape = self._data.data_shape
        num_units = opts['g_num_filters']
        num_layers = opts['g_num_layers']

        h0 = ops.linear(
            opts, noise, num_units * 8 * 8, scope='h0_lin')
        h0 = tf.reshape(h0, [-1, 8, 8, num_units])
        layer_x = h0
        # `range` (not the Python-2-only `xrange`) works on both Python 2 and 3.
        for i in range(num_layers):
            if i % 3 < 2:
                # Don't change resolution
                layer_x = ops.conv2d(opts, layer_x, num_units, d_h=1, d_w=1, scope='h%d_conv' % i)
                layer_x = tf.nn.elu(layer_x)
            elif i != num_layers - 1:
                # Upsampling by factor of 2 with NN. Floor division keeps
                # `scale` an int on Python 3 (identical result on Python 2).
                scale = 2 ** (i // 3 + 1)
                side = scale * 8
                layer_x = ops.upsample_nn(layer_x, [side, side],
                                          scope='h%d_upsample' % i, reuse=reuse)
                # Skip connection: bring h0 to the current resolution and
                # concatenate it along axis 3 (the last axis of the
                # [-1, 8, 8, num_units] reshape above).
                append = ops.upsample_nn(h0, [side, side],
                                         scope='h%d_skipup' % i, reuse=reuse)
                layer_x = tf.concat([layer_x, append], axis=3)

        last_h = ops.conv2d(opts, layer_x, output_shape[-1], d_h=1, d_w=1, scope='hlast_conv')

        if opts['input_normalize_sym']:
            return tf.nn.tanh(last_h)
        else:
            return tf.nn.sigmoid(last_h)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号