def _embed_sentences(self):
"""Tensorflow implementation of Simple but Tough-to-Beat Baseline"""
# Get word features
word_embeddings = self._get_embedding()
word_feats = tf.nn.embedding_lookup(word_embeddings, self.input)
# Get marginal estimates and scaling term
batch_size = tf.shape(word_feats)[0]
a = tf.pow(10.0, self._get_a_exp())
p = tf.constant(self.marginals, dtype=tf.float32, name='marginals')
q = tf.reshape(
a / (a + tf.nn.embedding_lookup(p, self.input)),
(batch_size, self.mx_len, 1)
)
# Compute initial sentence embedding
z = tf.reshape(1.0 / tf.to_float(self.input_lengths), (batch_size, 1))
S = z * tf.reduce_sum(q * word_feats, axis=1)
# Compute common component
S_centered = S - tf.reduce_mean(S, axis=0)
_, _, V = tf.svd(S_centered, full_matrices=False, compute_uv=True)
self.tf_ccx = tf.stop_gradient(tf.gather(tf.transpose(V), 0))
# Common component removal
ccx = tf.reshape(self._get_common_component(), (1, self.d))
sv = {'embeddings': word_embeddings, 'a': a, 'p': p, 'ccx': ccx}
return S - tf.matmul(S, ccx * tf.transpose(ccx)), sv
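
For reference, a minimal NumPy sketch of the same computation on toy inputs; the names E, p, X, lengths, and a below are hypothetical stand-ins for the word embeddings, marginals, input ids, input lengths, and scaling term used in the method, and padding handling is omitted:

import numpy as np

# Toy inputs (hypothetical): vocabulary of 5 words, embedding dim 4, 3 sentences of length 6
rng = np.random.default_rng(0)
E = rng.normal(size=(5, 4))            # word embedding matrix
p = rng.dirichlet(np.ones(5))          # unigram marginal probabilities
X = rng.integers(0, 5, size=(3, 6))    # token ids per sentence
lengths = np.full(3, 6)                # sentence lengths
a = 1e-3                               # SIF weighting parameter

# Weighted average of word vectors: (1/|s|) * sum_w a / (a + p(w)) * v_w
w = a / (a + p[X])
S = (w[..., None] * E[X]).sum(axis=1) / lengths[:, None]

# Common component removal: subtract the projection onto the first right
# singular vector of the mean-centered sentence embeddings
_, _, Vt = np.linalg.svd(S - S.mean(axis=0), full_matrices=False)
u = Vt[0]
S_sif = S - np.outer(S @ u, u)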