def apply(self, is_train, x, mask=None):
    """Pool the word dimension of `x` into a fixed number of attended encodings.

    Computes `n_encodings` attention distributions over the words of `x`
    (optionally from a transformed "keys" view of `x`), then returns the
    attention-weighted sums of `x`.

    Args:
        is_train: training-mode flag, passed through to sub-layers.
        x: float tensor; assumed shape (batch, x_words, feature) — implied by
           the tensordot over axis 2 and the einsum below (TODO confirm).
        mask: optional per-example sequence lengths used to mask out padding.

    Returns:
        Tensor of shape (batch, n_encodings, feature).
    """
    # Optionally derive attention keys from x via a learned mapper; otherwise
    # attend directly on x itself.
    if self.key_mapper is not None:
        with tf.variable_scope("map_keys"):
            keys = self.key_mapper.apply(is_train, x, mask)
    else:
        keys = x

    weights = tf.get_variable("weights", (keys.shape.as_list()[-1], self.n_encodings), dtype=tf.float32,
                              initializer=get_keras_initialization(self.init))
    dist = tf.tensordot(keys, weights, axes=[[2], [0]])  # (batch, x_words, n_encodings)
    if self.bias:
        dist += tf.get_variable("bias", (1, 1, self.n_encodings),
                                dtype=tf.float32, initializer=tf.zeros_initializer())

    if mask is not None:
        bool_mask = tf.expand_dims(tf.cast(tf.sequence_mask(mask, tf.shape(x)[1]), tf.float32), 2)
        # BUG FIX: the original line was `bool_mask * bool_mask + ...`, which
        # throws away the logits in `dist` — every valid position got logit 1.0,
        # making the attention uniform whenever a mask was supplied. The masked
        # softmax must keep the logits at valid positions and push padded
        # positions toward -inf so they receive ~zero weight.
        dist = dist * bool_mask + (1 - bool_mask) * VERY_NEGATIVE_NUMBER

    dist = tf.nn.softmax(dist, dim=1)  # normalize over the word dimension
    out = tf.einsum("ajk,ajn->ank", x, dist)  # (batch, n_encodings, feature)

    if self.post_process is not None:
        with tf.variable_scope("post_process"):
            out = self.post_process.apply(is_train, out)
    return out