resnet_imagenet_model_wrapper.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tfplus 作者: renmengye 项目源码 文件源码
def build_loss(self, inp, output):
        y_gt = inp['y_gt']
        y_out = output['y_out']
        ce = tfplus.nn.CE()({'y_gt': y_gt, 'y_out': y_out})
        num_ex_f = tf.to_float(tf.shape(inp['x'])[0])
        ce = tf.reduce_sum(ce) / num_ex_f
        self.add_loss(ce)
        total_loss = self.get_loss()
        self.register_var('loss', total_loss)

        ans = tf.argmax(y_gt, 1)
        correct = tf.equal(ans, tf.argmax(y_out, 1))
        top5_acc = tf.reduce_sum(tf.to_float(
            tf.nn.in_top_k(y_out, ans, 5))) / num_ex_f
        self.register_var('top5_acc', top5_acc)
        acc = tf.reduce_sum(tf.to_float(correct)) / num_ex_f
        self.register_var('acc', acc)
        return total_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号