def apply(self, is_train, x, mask=None):
    """Attention-pool the sequence ``x`` into one vector per example.

    Every position gets a scalar score: the dot product of its key vector
    with a learned weight vector. Scores are optionally masked and then
    softmax-normalized over the sequence axis. The result is the
    score-weighted sum of ``x`` over that axis.

    Args:
        is_train: Training flag, forwarded to ``self.key_mapper`` and
            ``self.post_process``.
        x: Rank-3 input tensor, (batch, x_words, x_dim).
        mask: Passed to ``exp_mask`` to exclude positions from the softmax
            (presumably per-example sequence lengths -- confirm against
            ``exp_mask``). ``None`` means no masking.

    Returns:
        A (batch, x_dim) tensor. If ``self.post_process`` is set, its
        output is returned instead.
    """
    # Scoring keys are the inputs themselves unless a key mapper is set.
    if self.key_mapper is not None:
        with tf.variable_scope("map_keys"):
            keys = self.key_mapper.apply(is_train, x, mask)
    else:
        keys = x

    # One learned weight per key feature, so every position gets one score.
    key_dim = keys.shape.as_list()[-1]
    weights = tf.get_variable("weights", [key_dim], dtype=tf.float32,
                              initializer=get_keras_initialization(self.init))

    dist = tf.tensordot(keys, weights, axes=[[2], [0]])  # (batch, x_words)
    # Mask only when one was supplied. The default mask=None would otherwise
    # be handed to exp_mask, whose handling of None is not visible here.
    if mask is not None:
        dist = exp_mask(dist, mask)
    dist = tf.nn.softmax(dist)  # normalize over the last axis (x_words)

    # Weighted sum over the sequence axis j.
    out = tf.einsum("ajk,aj->ak", x, dist)  # (batch, x_dim)

    if self.post_process is not None:
        with tf.variable_scope("post_process"):
            out = self.post_process.apply(is_train, out)
    return out
评论列表
文章目录