vae.py source code

python

Project: KATE  Author: hugochan
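    def vae_loss(self, x, x_decoded_mean):
        # Reconstruction term: per-feature binary cross-entropy, summed over features.
        # Argument order (output, target) follows the Keras 1.x backend; Keras 2 expects (target, output).
        xent_loss = K.sum(K.binary_crossentropy(x_decoded_mean, x), axis=-1)
        # KL divergence between N(z_mean, exp(z_log_var)) and the standard normal prior.
        kl_loss = -0.5 * K.sum(1 + self.z_log_var - K.square(self.z_mean) - K.exp(self.z_log_var), axis=-1)

        return xent_loss + kl_loss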

    # def weighted_vae_loss(self, feature_weights):
    #     def loss(y_true, y_pred):
    #         try:
    #             x = K.binary_crossentropy(y_pred, y_true)
    #             y = tf.Variable(feature_weights.astype('float32'))
    #             # y2 = y_true / K.sum(y_true)
    #             # import pdb;pdb.set_trace()
    #             xent_loss = K.dot(x, y)
    #             kl_loss = - 0.5 * K.sum(1 + self.z_log_var - K.square(self.z_mean) - K.exp(self.z_log_var), axis=-1)
    #         except Exception as e:
    #             print(e)
    #             import pdb;pdb.set_trace()
    #         return xent_loss + kl_loss
    #     return loss
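
The loss above can be checked by hand with a small NumPy restatement. This is only a sketch, not part of the KATE project: the function name `vae_loss_np`, the clipping constant `eps`, and the sample arrays are assumptions made for illustration. It computes the same two terms (summed binary cross-entropy plus the KL divergence to a standard normal prior) on plain arrays.

```python
import numpy as np

def vae_loss_np(x, x_decoded_mean, z_mean, z_log_var, eps=1e-7):
    """NumPy restatement of vae_loss (hypothetical helper, for checking values)."""
    # Clip predictions away from 0 and 1 so the logarithms stay finite.
    p = np.clip(x_decoded_mean, eps, 1.0 - eps)
    # Binary cross-entropy per feature, summed over the feature axis.
    xent_loss = -np.sum(x * np.log(p) + (1.0 - x) * np.log(1.0 - p), axis=-1)
    # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
    # summed over the latent dimensions.
    kl_loss = -0.5 * np.sum(
        1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var), axis=-1
    )
    return xent_loss + kl_loss

# One sample with three features and a two-dimensional latent code (made-up values).
x = np.array([[1.0, 0.0, 1.0]])
x_hat = np.array([[0.5, 0.5, 0.5]])
z_mean = np.zeros((1, 2))
z_log_var = np.zeros((1, 2))

loss = vae_loss_np(x, x_hat, z_mean, z_log_var)
```

With `z_mean` and `z_log_var` both zero, the approximate posterior equals the prior, so the KL term vanishes and the loss reduces to the reconstruction term alone, here three times ln 2.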