def _jsma_impl(model, x, yind, epochs, eps, clip_min, clip_max, score_fn):
    """Build the graph for a single-element Jacobian-based saliency map attack.

    Runs exactly ``epochs`` iterations of a ``tf.while_loop`` (there is no
    early stopping). In each iteration, for every sample in the batch, the
    input element with the highest saliency score is moved by ``eps``. The
    result is then clipped to ``[clip_min, clip_max]``.

    Args:
        model: Callable mapping an input batch to the model output ``ybar``.
        x: Input batch. Its per-sample shape must be statically known,
            because it is read via ``get_shape().as_list()``.
        yind: Indices for ``tf.gather_nd(ybar, yind)`` selecting the target
            output of each sample. Presumably ``[batch, 2]`` rows of
            ``(sample, target_class)`` -- confirm against the caller.
        epochs: Number of iterations to run.
        eps: Amount added to the selected element each iteration. Its sign
            sets the direction: ``eps > 0`` increases elements and
            ``eps < 0`` decreases them.
        clip_min: Lower bound applied to every element after each step.
        clip_max: Upper bound applied to every element after each step.
        score_fn: Callable ``score_fn(dt_dx, do_dx)`` returning a
            per-element saliency score. It receives the target and
            non-target gradients oriented along ``sign(eps)``; for
            ``eps > 0`` these are exactly the raw gradients.

    Returns:
        The adversarial batch, with the same shape as ``x``.
    """

    def _cond(i, xadv):
        return tf.less(i, epochs)

    def _body(i, xadv):
        ybar = model(xadv)

        # tf.gradients of a non-scalar sums over all outputs, so this is
        # the gradient of (sum of every output) w.r.t. the input.
        dy_dx = tf.gradients(ybar, xadv)[0]

        # gradients of target w.r.t input
        yt = tf.gather_nd(ybar, yind)
        dt_dx = tf.gradients(yt, xadv)[0]

        # gradients of non-targets w.r.t input
        do_dx = dy_dx - dt_dx

        # Moving an element by eps changes an output by roughly
        # eps * gradient, so orient the gradients along sign(eps). For
        # eps > 0 this multiplies by 1.0 (no change). For eps < 0 it makes
        # the test below select elements whose *decrease* raises the target
        # and lowers the non-targets, instead of the opposite.
        direction = tf.sign(tf.cast(eps, dt_dx.dtype))
        dt_dx = direction * dt_dx
        do_dx = direction * do_dx

        # Only consider elements that can still move in the direction of
        # eps without already sitting at the corresponding clip bound.
        c0 = tf.logical_or(eps < 0, xadv < clip_max)
        c1 = tf.logical_or(eps > 0, xadv > clip_min)
        salient = tf.reduce_all([dt_dx >= 0, do_dx <= 0, c0, c1], axis=0)
        cond = tf.cast(salient, tf.float32)

        # saliency score for each element; elements failing the test score 0
        score = cond * score_fn(dt_dx, do_dx)

        # Flatten each sample. _prod (defined elsewhere in this module)
        # presumably multiplies the static dims, giving elements per sample.
        shape = score.get_shape().as_list()
        dim = _prod(shape[1:])
        score = tf.reshape(score, [-1, dim])

        # find the element with the highest saliency score
        ind = tf.argmax(score, axis=1)
        dx = tf.one_hot(ind, dim, on_value=eps, off_value=0.0)

        # If no element of a sample passes the saliency test, argmax over
        # its all-zero row returns index 0 and would perturb an element that
        # failed the test. Leave such samples unchanged instead.
        has_candidate = tf.reduce_any(tf.reshape(salient, [-1, dim]), axis=1)
        dx = dx * tf.expand_dims(tf.cast(has_candidate, dx.dtype), 1)

        dx = tf.reshape(dx, [-1] + shape[1:])

        # The loop is built with back_prop=False; stop_gradient additionally
        # keeps the update itself out of any gradient computation.
        xadv = tf.stop_gradient(xadv + dx)
        xadv = tf.clip_by_value(xadv, clip_min, clip_max)

        return i + 1, xadv

    _, xadv = tf.while_loop(_cond, _body, (0, tf.identity(x)),
                            back_prop=False, name='_jsma_batch')

    return xadv
# saliency_map.py source code
# Language: python
# Views: 31
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Article table of contents