import tensorflow as tf  # TensorFlow 1.x API (matrix_band_part, to_float, log)


def attention_bias(inputs, mode, inf=-1e9, name=None):
""" A bias tensor used in attention mechanism
:param inputs:
:param mode:
:param inf:
:param name:
:returns:
"""
    with tf.name_scope(name, default_name="attention_bias", values=[inputs]):
        if mode == "causal":
            length = inputs
            # Lower-triangular ones: entry (i, j) is 1 iff j <= i, so each
            # query may only attend to the current and earlier positions.
            lower_triangle = tf.matrix_band_part(
                tf.ones([length, length]), -1, 0
            )
            ret = inf * (1.0 - lower_triangle)
            # Reshape for broadcasting over the batch and head dimensions.
            return tf.reshape(ret, [1, 1, length, length])
elif mode == "masking":
mask = inputs
ret = (1.0 - mask) * inf
return tf.expand_dims(tf.expand_dims(ret, 1), 1)
elif mode == "proximal":
length = inputs
r = tf.to_float(tf.range(length))
diff = tf.expand_dims(r, 0) - tf.expand_dims(r, 1)
m = tf.expand_dims(tf.expand_dims(-tf.log(1 + tf.abs(diff)), 0), 0)
return m
elif mode == "distance":
length, distance = inputs
distance = tf.where(distance > length, 0, distance)
distance = tf.cast(distance, tf.int64)
lower_triangle = tf.matrix_band_part(
tf.ones([length, length]), -1, 0
)
mask_triangle = 1.0 - tf.matrix_band_part(
tf.ones([length, length]), distance - 1, 0
)
ret = inf * (1.0 - lower_triangle + mask_triangle)
return tf.reshape(ret, [1, 1, length, length])
        else:
            raise ValueError("Unknown mode %s" % mode)
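
A minimal usage sketch (not part of the original code; it assumes TensorFlow 1.x graph mode, and the shapes and values are illustrative only): the returned bias is simply added to the raw attention logits before the softmax, relying on broadcasting to cover the batch and head dimensions.

# Usage sketch -- illustrative only, assuming TensorFlow 1.x graph mode.
logits = tf.zeros([2, 8, 5, 5])                # [batch, heads, queries, keys]
causal_bias = attention_bias(5, "causal")      # [1, 1, 5, 5]
weights = tf.nn.softmax(logits + causal_bias)  # masked positions get ~0 weight

mask = tf.constant([[1.0, 1.0, 1.0, 0.0, 0.0]])  # 1 = real token, 0 = padding
pad_bias = attention_bias(mask, "masking")       # [1, 1, 1, 5]
pad_weights = tf.nn.softmax(logits + pad_bias)   # padded keys get ~0 weight

with tf.Session() as sess:
    print(sess.run(weights)[0, 0])  # rows sum to 1; upper triangle is ~0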