def __init__(
        self,
        images,
        logits,
        bounds,
        channel_axis=3,
        preprocessing=(0, 1)):
    """Wrap a TensorFlow graph mapping `images` to `logits` as a model.

    Builds, inside the session's graph, the tensors used to evaluate the
    logits, the cross-entropy loss, its gradient w.r.t. the input, and a
    backward (vector-Jacobian) gradient for a fed upstream gradient.

    Parameters
    ----------
    images : tensorflow.Tensor
        Input tensor of the graph. Gradients w.r.t. it are squeezed along
        axis 0, so it presumably carries a leading batch axis of size 1 —
        TODO confirm against callers.
    logits : tensorflow.Tensor
        Output tensor computed from `images`. Axis 0 is squeezed away, so
        it is expected to have a leading batch axis of size 1.
    bounds
        Forwarded unchanged to the base-class constructor.
    channel_axis : int
        Forwarded unchanged to the base-class constructor.
    preprocessing
        Forwarded unchanged to the base-class constructor.

    Raises
    ------
    ValueError
        If a default TensorFlow session is active but its graph is not the
        graph that `images` belongs to.
    """
    super(TensorFlowModel, self).__init__(bounds=bounds,
                                          channel_axis=channel_axis,
                                          preprocessing=preprocessing)

    # delay import until class is instantiated
    import tensorflow as tf

    session = tf.get_default_session()
    if session is None:
        session = tf.Session(graph=images.graph)
        # Record that this instance created the session itself
        # (presumably so it can be closed later — verify in the class).
        self._created_session = True
    else:
        self._created_session = False
        # Every tensor below is built in `session.graph`; a default
        # session running some other graph could never evaluate them, so
        # fail early with a clear message instead of a cryptic
        # cross-graph error further down.
        if session.graph is not images.graph:
            raise ValueError(
                'The default session uses a different graph than the '
                'one containing `images`')

    with session.graph.as_default():
        self._session = session
        self._images = images
        self._batch_logits = logits
        # Drop the leading (size-1) batch axis of the logits.
        self._logits = tf.squeeze(logits, axis=0)
        # Scalar int64 class index, fed at run time.
        self._label = tf.placeholder(tf.int64, (), name='label')

        # Cross-entropy of the logits w.r.t. `self._label`; both inputs
        # get a temporary leading axis of size 1, removed again after.
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self._label[tf.newaxis],
            logits=self._logits[tf.newaxis])
        self._loss = tf.squeeze(loss, axis=0)
        gradients = tf.gradients(loss, images)
        assert len(gradients) == 1
        self._gradient = tf.squeeze(gradients[0], axis=0)

        # Backward pass: an upstream gradient w.r.t. the logits is fed
        # through this placeholder; differentiating
        # sum(logits * upstream) w.r.t. `images` yields the
        # vector-Jacobian product.
        self._bw_gradient_pre = tf.placeholder(
            tf.float32, self._logits.shape)
        bw_loss = tf.reduce_sum(self._logits * self._bw_gradient_pre)
        bw_gradients = tf.gradients(bw_loss, images)
        assert len(bw_gradients) == 1
        self._bw_gradient = tf.squeeze(bw_gradients[0], axis=0)