def lookup_shift(x,
                 context,
                 shift=True,
                 scale=True,
                 scope=None,
                 reuse=None):
    """Scale and/or shift `x` using parameters looked up from `context`.

    One `gamma` (scale) row and one `beta` (shift) row is stored per
    position along the last axis of `context`. For each example the row is
    chosen by the argmax of `context` over its last axis, then multiplied
    into / added onto `x`.

    Args:
        x: Tensor to modulate. Its static last dimension is used as the
            parameter width. Presumably laid out as (batch, ..., channels)
            -- TODO confirm against callers.
        context: Tensor whose static last dimension gives the number of
            parameter rows. Presumably a one-hot (or score) matrix of shape
            (batch, num_contexts) -- TODO confirm against callers.
        shift: If true, add the looked-up `beta` (initialized to zeros).
        scale: If true, multiply by the looked-up `gamma` (initialized to
            ones).
        scope: Optional variable scope name; defaults to 'lookup_shift'.
        reuse: Passed through to `tf.variable_scope`.

    Returns:
        The modulated tensor. When both `scale` and `shift` are false, `x`
        itself is returned.
    """
    # Use the public TensorShape API instead of the private
    # Tensor._shape_as_list() helper.
    num_contexts = context.shape.as_list()[-1]
    num_channels = x.shape.as_list()[-1]
    ndim = len(x.shape)

    # One row per context; singleton middle axes let the looked-up rows
    # broadcast against the non-leading, non-trailing axes of `x`.
    var_shape = [num_contexts] + [1] * (ndim - 2) + [num_channels]

    with tf.variable_scope(scope, 'lookup_shift', reuse=reuse):
        output = x
        # Index of the largest entry along the last axis of `context`.
        ids = tf.argmax(context, -1)
        if scale:
            gamma = tf.get_variable(
                'gamma', var_shape, initializer=tf.ones_initializer())
            output *= tf.nn.embedding_lookup(gamma, ids)
        if shift:
            beta = tf.get_variable(
                'beta', var_shape, initializer=tf.zeros_initializer())
            output += tf.nn.embedding_lookup(beta, ids)
    return output
# [web-page residue, not code] Comment list
# [web-page residue, not code] Table of contents