tensorflow.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:foolbox 作者: bethgelab 项目源码 文件源码
def __init__(
            self,
            images,
            logits,
            bounds,
            channel_axis=3,
            preprocessing=(0, 1)):
        """Wrap an existing TensorFlow graph and add loss/gradient nodes.

        Parameters
        ----------
        images
            Input tensor of the model. Its ``graph`` is used to create a
            session when no default session is active, and all gradients
            are taken with respect to it.
        logits
            Output tensor of the model. Axis 0 is squeezed away to obtain
            the single-sample logits, so that axis must have size 1 when
            the single-sample nodes are evaluated.
        bounds
            Forwarded unchanged to the base class constructor.
        channel_axis
            Forwarded unchanged to the base class constructor.
        preprocessing
            Forwarded unchanged to the base class constructor.

        Notes
        -----
        If a default session is active it is reused and
        ``self._created_session`` is set to ``False``; otherwise a new
        ``tf.Session`` is created on ``images.graph`` and the flag is set
        to ``True``.
        """
        super(TensorFlowModel, self).__init__(bounds=bounds,
                                              channel_axis=channel_axis,
                                              preprocessing=preprocessing)

        # delay import until class is instantiated
        import tensorflow as tf

        session = tf.get_default_session()
        if session is None:
            session = tf.Session(graph=images.graph)
            self._created_session = True
        else:
            self._created_session = False

        with session.graph.as_default():
            self._session = session
            self._images = images
            self._batch_logits = logits
            self._logits = tf.squeeze(logits, axis=0)
            self._label = tf.placeholder(tf.int64, (), name='label')

            # Cross-entropy of the single sample against the fed label; the
            # op works on batches, so a batch axis is added and removed again.
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self._label[tf.newaxis],
                logits=self._logits[tf.newaxis])
            self._loss = tf.squeeze(loss, axis=0)
            gradients = tf.gradients(loss, images)
            assert len(gradients) == 1
            loss_gradient = gradients[0]
            if loss_gradient is None:
                # tf.gradients yields None when the loss does not depend on
                # images; a zero gradient is the mathematically correct
                # stand-in and keeps tf.squeeze from failing on None.
                loss_gradient = tf.zeros_like(images)
            self._gradient = tf.squeeze(loss_gradient, axis=0)

            # Placeholder for a gradient w.r.t. the logits that is to be
            # propagated back to the images. It follows the dtype of the
            # logits so that the product below is valid for non-float32
            # models as well.
            self._bw_gradient_pre = tf.placeholder(
                self._logits.dtype, self._logits.shape)
            bw_loss = tf.reduce_sum(self._logits * self._bw_gradient_pre)
            bw_gradients = tf.gradients(bw_loss, images)
            assert len(bw_gradients) == 1
            bw_gradient = bw_gradients[0]
            if bw_gradient is None:
                # Same unconnected-graph case as for the loss gradient.
                bw_gradient = tf.zeros_like(images)
            self._bw_gradient = tf.squeeze(bw_gradient, axis=0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号