ttbb_models.py source code

python

Project: shalo  Author: henryre
def _embed_sentences(self):
        """Tensorflow implementation of Simple but Tough-to-Beat Baseline"""
        # Get word features
        word_embeddings = self._get_embedding()
        word_feats      = tf.nn.embedding_lookup(word_embeddings, self.input)
        # Per-word weights q = a / (a + p(w)), where p holds the word
        # marginals and a = 10 ** a_exp is the scaling term
        batch_size = tf.shape(word_feats)[0]
        a = tf.pow(10.0, self._get_a_exp())
        p = tf.constant(self.marginals, dtype=tf.float32, name='marginals')
        q = tf.reshape(
            a / (a + tf.nn.embedding_lookup(p, self.input)),
            (batch_size, self.mx_len, 1)
        )
        # Initial sentence embedding: weighted sum of word vectors,
        # scaled by 1 / sentence length
        z = tf.reshape(
            1.0 / tf.cast(self.input_lengths, tf.float32), (batch_size, 1)
        )
        S = z * tf.reduce_sum(q * word_feats, axis=1)
        # Common component: first right singular vector of the centered
        # sentence embeddings (no gradient flows through it)
        S_centered = S - tf.reduce_mean(S, axis=0)
        _, _, V = tf.svd(S_centered, full_matrices=False, compute_uv=True)
        self.tf_ccx = tf.stop_gradient(tf.gather(tf.transpose(V), 0))
        # Common component removal: subtract the projection of S onto ccx
        ccx = tf.reshape(self._get_common_component(), (1, self.d))
        sv = {'embeddings': word_embeddings, 'a': a, 'p': p, 'ccx': ccx}
        return S - tf.matmul(S, ccx * tf.transpose(ccx)), sv
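
The same computation can be sketched in plain NumPy, which may make the tensor shapes easier to follow. This is only an illustrative sketch, not the project's code: the function name `sif_embed`, its arguments, and the default `a=1e-3` are hypothetical. It also differs from the method above in two labeled ways: it masks padding positions explicitly, and it uses the common component estimated from the current batch, whereas the original obtains it from `self._get_common_component()`.

```python
import numpy as np


def sif_embed(embeddings, token_ids, lengths, marginals, a=1e-3):
    """Weighted average of word vectors followed by common component removal.

    embeddings: (vocab, d) word vectors
    token_ids:  (batch, max_len) integer word ids, padded to max_len
    lengths:    (batch,) true sentence lengths
    marginals:  (vocab,) word marginal probabilities p(w)
    a:          scaling term (hypothetical default)
    """
    word_feats = embeddings[token_ids]              # (batch, max_len, d)
    q = a / (a + marginals[token_ids])              # (batch, max_len)

    # Assumption: padding positions are zeroed out with an explicit mask.
    # The original method does not show how padding is handled.
    positions = np.arange(token_ids.shape[1])[None, :]
    q = q * (positions < lengths[:, None])

    # Initial sentence embedding: weighted sum scaled by 1 / length
    S = (q[:, :, None] * word_feats).sum(axis=1) / lengths[:, None]

    # Common component: first right singular vector of the centered batch
    S_centered = S - S.mean(axis=0)
    _, _, Vt = np.linalg.svd(S_centered, full_matrices=False)
    ccx = Vt[0:1]                                   # (1, d), unit norm

    # Remove the projection of each sentence embedding onto ccx
    return S - S @ (ccx.T @ ccx), ccx


# Usage with random stand-in data
rng = np.random.default_rng(0)
emb = rng.normal(size=(20, 5))
ids = rng.integers(0, 20, size=(6, 4))
lens = np.array([4, 3, 2, 4, 1, 3])
p = np.full(20, 1.0 / 20)
sentence_vecs, common = sif_embed(emb, ids, lens, p)
```

Because `ccx` has unit norm, the returned embeddings are orthogonal to it, which is a quick way to check that the removal step did what it should.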