def _jsma2_impl(model, x, yind, epochs, eps, clip_min, clip_max, score_fn):
    """Build the pairwise saliency-map (JSMA-style) perturbation loop.

    Each of ``epochs`` iterations picks the pair of input features with the
    highest combined saliency score, adds ``eps`` to both of them, and clips
    the result to ``[clip_min, clip_max]``.

    Args:
        model: Callable mapping an input tensor to an output tensor ``ybar``
            (presumably logits or probabilities -- not visible here).
        x: Input tensor. The leading axis is treated as the batch axis. All
            other dimensions must be statically known, because the feature
            count is derived from ``get_shape()``.
        yind: Indices passed to ``tf.gather_nd(ybar, yind)`` to select the
            target output for each sample (presumably ``(row, target)``
            pairs -- confirm against the caller).
        epochs: Number of loop iterations.
        eps: Step added to each selected feature. With ``eps > 0`` only
            features still below ``clip_max`` are eligible. With ``eps < 0``
            only features still above ``clip_min`` are eligible.
        clip_min: Lower bound applied after every step.
        clip_max: Upper bound applied after every step.
        score_fn: Callable ``score_fn(dt_dx, do_dx)`` returning a per-feature
            saliency score with the same shape as ``x``.

    Returns:
        The perturbed tensor, with the same shape as ``x``. The loop is built
        with ``back_prop=False``, so no gradient flows back through it.
    """
    def _cond(k, xadv):
        return tf.less(k, epochs)

    def _body(k, xadv):
        ybar = model(xadv)

        # tf.gradients sums over a non-scalar target, so this is the gradient
        # of the sum of all outputs w.r.t. the input.
        dy_dx = tf.gradients(ybar, xadv)[0]

        # gradients of target w.r.t input
        yt = tf.gather_nd(ybar, yind)
        dt_dx = tf.gradients(yt, xadv)[0]

        # gradients of non-targets w.r.t input
        do_dx = dy_dx - dt_dx

        # Eligible features must satisfy all of the following:
        #   - the target gradient is non-negative;
        #   - the non-target gradient is non-positive;
        #   - the feature is not already at the clip bound in the
        #     direction eps moves it.
        c0 = tf.logical_or(eps < 0, xadv < clip_max)
        c1 = tf.logical_or(eps > 0, xadv > clip_min)
        cond = tf.reduce_all([dt_dx >= 0, do_dx <= 0, c0, c1], axis=0)
        cond = tf.cast(cond, tf.float32)

        # Saliency score for each feature. Ineligible features score 0,
        # assuming score_fn returns finite values.
        score = cond * score_fn(dt_dx, do_dx)

        shape = score.get_shape().as_list()
        dim = _prod(shape[1:])
        score = tf.reshape(score, [-1, dim])

        # Pairwise scores, flattened:
        #   score2[n, r * dim + c] = score[n, r] + score[n, c]
        # NOTE(review): the grid includes r == c. The same feature can
        # therefore be picked twice and receive 2 * eps in a single step.
        a = tf.expand_dims(score, axis=1)
        b = tf.expand_dims(score, axis=2)
        score2 = tf.reshape(a + b, [-1, dim * dim])
        ij = tf.argmax(score2, axis=1)  # int64 flat index into the grid

        # Decode (row, col) in int64 with integer arithmetic and narrow only
        # afterwards. Casting ij to int32 first would overflow once
        # dim * dim >= 2**31.
        i = tf.cast(ij // dim, tf.int32)
        j = tf.cast(ij % dim, tf.int32)

        dxi = tf.one_hot(i, dim, on_value=eps, off_value=0.0)
        dxj = tf.one_hot(j, dim, on_value=eps, off_value=0.0)
        dx = tf.reshape(dxi + dxj, [-1] + shape[1:])

        # The perturbed input is treated as a constant by later iterations.
        xadv = tf.stop_gradient(xadv + dx)
        xadv = tf.clip_by_value(xadv, clip_min, clip_max)

        return k + 1, xadv

    _, xadv = tf.while_loop(_cond, _body, (0, tf.identity(x)),
                            back_prop=False, name='_jsma2_batch')
    return xadv
saliency_map.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录