srez_model.py source code

python

Project: tensorflow-srgan  Author: olgaliak
def create_generator_loss(disc_output, gene_output, features):
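    # Did we fool the discriminator? The target label is 1 ("real") for every
    # generated sample; disc_output holds the discriminator's raw logits.
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(disc_output), logits=disc_output)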
    gene_ce_loss  = tf.reduce_mean(cross_entropy, name='gene_ce_loss')

    # Does the generated image, downscaled by a factor of K, match the
    # low-resolution input (features)?
    K = int(gene_output.get_shape()[1])//int(features.get_shape()[1])
    assert K in (2, 4, 8)
    downscaled = _downscale(gene_output, K)

    gene_l1_loss  = tf.reduce_mean(tf.abs(downscaled - features), name='gene_l1_loss')

    gene_loss     = tf.add((1.0 - FLAGS.gene_l1_factor) * gene_ce_loss,
                           FLAGS.gene_l1_factor * gene_l1_loss, name='gene_loss')

    return gene_loss
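
To make the arithmetic concrete, here is a minimal pure-Python sketch of the same blend on tiny inputs. It is an illustration, not the project's code: `sigmoid_cross_entropy`, `box_downscale`, and `generator_loss` are hypothetical names, images are single-channel nested lists, and `box_downscale` assumes `_downscale` behaves like a K-by-K box average (the snippet above does not show its body).

```python
import math


def sigmoid_cross_entropy(logit, label):
    # Numerically stable form: max(x, 0) - x*z + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def box_downscale(image, k):
    # Hypothetical stand-in for _downscale: average each k-by-k block.
    height, width = len(image), len(image[0])
    return [
        [
            sum(image[r + i][c + j] for i in range(k) for j in range(k)) / (k * k)
            for c in range(0, width, k)
        ]
        for r in range(0, height, k)
    ]


def generator_loss(disc_logits, gene_image, feature_image, l1_factor):
    # Adversarial term: cross-entropy against the target label 1 ("real").
    ce_loss = sum(sigmoid_cross_entropy(x, 1.0) for x in disc_logits) / len(disc_logits)

    # Consistency term: mean absolute error between the downscaled output
    # and the low-resolution input.
    k = len(gene_image) // len(feature_image)
    assert k in (2, 4, 8)
    downscaled = box_downscale(gene_image, k)
    diffs = [
        abs(d - f)
        for d_row, f_row in zip(downscaled, feature_image)
        for d, f in zip(d_row, f_row)
    ]
    l1_loss = sum(diffs) / len(diffs)

    # Weighted blend, mirroring the gene_l1_factor mix in the function above.
    return (1.0 - l1_factor) * ce_loss + l1_factor * l1_loss


# A 2x2 "generated" image against a 1x1 low-resolution input (K = 2).
loss = generator_loss([0.0], [[1.0, 3.0], [5.0, 7.0]], [[3.0]], l1_factor=0.9)
```

With a logit of 0 the discriminator is undecided, so the cross-entropy term equals log 2. The 2x2 block averages to 4.0, which puts the L1 term at 1.0 against the input value 3.0. A larger `l1_factor` shifts weight from fooling the discriminator toward matching the low-resolution input.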