category_loss.py source code

python

Project: HyperGAN  Author: 255BITS
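import numpy as np
import tensorflow as tf

TINY = 1e-8  # small epsilon to avoid log(0); assumed value, the original module defines TINY elsewhere


def categories_loss(self, categories, layer):
    # Method of a HyperGAN loss class; self.gan must provide batch_size().
    # categories: list of one-hot code tensors, each of shape [batch_size, k_i]
    # layer: logits of shape [batch_size, sum(k_i)], concatenated in the same order
    # Returns sum_i (H(c_i) - E[-log Q(c_i|x)]), an InfoGAN-style lower bound on I(c; x)
    gan = self.gan
    loss = 0
    batch_size = gan.batch_size()

    def split(layer):
        # Cut the concatenated logits into one slice per categorical code
        start = 0
        ret = []
        for category in categories:
            count = int(category.get_shape()[1])
            ret.append(tf.slice(layer, [0, start], [batch_size, count]))
            start += count
        return ret

    for category, layer_s in zip(categories, split(layer)):
        size = int(category.get_shape()[1])
        # Uniform prior p(c) = 1/size for every class
        category_prior = tf.ones([batch_size, size]) * np.float32(1. / size)
        # log p(c) of the sampled one-hot code
        logli_prior = tf.reduce_sum(tf.math.log(category_prior + TINY) * category, axis=1)
        # log Q(c|x) under the softmax of the predicted logits
        layer_softmax = tf.nn.softmax(layer_s)
        logli = tf.reduce_sum(tf.math.log(layer_softmax + TINY) * category, axis=1)
        disc_ent = tf.reduce_mean(-logli_prior)   # prior entropy H(c), i.e. log(size)
        disc_cross_ent = tf.reduce_mean(-logli)   # cross-entropy E[-log Q(c|x)]

        loss += disc_ent - disc_cross_ent
    return loss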
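For readers without TensorFlow or HyperGAN at hand, the same arithmetic can be sketched in plain NumPy. This is a minimal illustration, not HyperGAN's API: the helper names `softmax` and `categories_mi_bound` are hypothetical, the `TINY` value is assumed, and the inputs are assumed to be a list of one-hot arrays plus one array of logits concatenated in the same order. Like the method above, it uses a uniform prior, so each code contributes `log(size) - E[-log Q(c|x)]`.

```python
import numpy as np

TINY = 1e-8  # assumed epsilon, playing the same role as TINY in the original


def softmax(x):
    # Row-wise softmax with max subtraction for numerical stability
    z = x - x.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def categories_mi_bound(categories, logits):
    """Hypothetical NumPy sketch of categories_loss.

    categories: list of one-hot arrays, each of shape [batch, k_i]
    logits: array of shape [batch, sum(k_i)], concatenated in the same order
    """
    total = 0.0
    start = 0
    for cat in categories:
        size = cat.shape[1]
        layer_s = logits[:, start:start + size]  # same slicing as split()
        start += size
        prior = np.full(cat.shape, 1.0 / size)   # uniform prior p(c)
        logli_prior = np.sum(np.log(prior + TINY) * cat, axis=1)
        logli = np.sum(np.log(softmax(layer_s) + TINY) * cat, axis=1)
        disc_ent = np.mean(-logli_prior)         # H(c) for the uniform prior
        disc_cross_ent = np.mean(-logli)         # E[-log Q(c|x)]
        total += disc_ent - disc_cross_ent
    return total


# Usage: a 3-way and a 2-way categorical code, batch of 2
c1 = np.eye(3)[[0, 2]]
c2 = np.eye(2)[[1, 0]]
print(categories_mi_bound([c1, c2], np.zeros((2, 5))))  # → 0.0
confident = 50.0 * np.concatenate([c1, c2], axis=1)
print(categories_mi_bound([c1, c2], confident))  # close to np.log(6), about 1.792
```

With all-zero logits the softmax equals the uniform prior, so each code's two terms cancel exactly. A predictor that is confident and correct drives the cross-entropy toward zero, so the total approaches the sum of the prior entropies, here log 3 + log 2 = log 6.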