shalo_base.py source code

python

Project: shalo  Author: henryre
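def _get_embedding(self):
    """
    Return the embedding tensor (either constant or variable).
    Row 0 is a zero vector for "no token".
    Row 1 is a random initialization for UNKNOWN.
    Rows 2 : 2 + len(self.embedding_words_train) are the pretrained initialization.
    Remaining rows are random initialization.
    """
    # Method excerpted from a class in shalo_base.py. Assumes `import tensorflow as tf`
    # (TensorFlow 1.x API) and a module-level constant SD (stddev of the random
    # initialization) defined elsewhere in the file (not shown here).
    # Row 0: constant zero vector, so it is never updated during training.
    zero = tf.constant(0.0, dtype=tf.float32, shape=(1, self.d))
    s = self.seed - 1
    # Row 1: trainable random vector for unknown tokens.
    unk = tf.Variable(tf.random_normal((1, self.d), stddev=SD, seed=s))
    # Pretrained vectors, wrapped in a Variable so they are fine-tuned.
    pretrain = tf.Variable(self.embeddings_train, dtype=tf.float32)
    vecs = [zero, unk, pretrain]
    # Vocabulary words without a pretrained vector get randomly initialized rows.
    n_r = self.word_dict.num_words() - len(self.embedding_words_train)
    if n_r > 0:
        r = tf.Variable(tf.random_normal((n_r, self.d), stddev=SD, seed=s))
        vecs.append(r)
    self.U = tf.concat(vecs, axis=0, name='embedding_matrix')
    return self.U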
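For readers without a TensorFlow 1.x setup, the row layout described in the docstring can be mirrored in NumPy. This is a minimal sketch, not the shalo implementation: the function name build_embedding_matrix, the sd default of 0.1 (standing in for the SD constant, whose real value is not shown) and the toy sizes are hypothetical, and NumPy's random generator replaces tf.random_normal, so the random values will not match TensorFlow's.

```python
import numpy as np


def build_embedding_matrix(pretrained, num_words, d, sd=0.1, seed=None):
    """Hypothetical NumPy mirror of the _get_embedding row layout.

    pretrained: array of shape (n_pre, d) holding pretrained vectors
    num_words:  vocabulary size (plays the role of word_dict.num_words())
    sd:         stddev for random rows (assumed value, stands in for SD)
    """
    rng = np.random.default_rng(seed)
    n_pre, d_pre = pretrained.shape
    if d_pre != d:
        raise ValueError("pretrained vectors must have dimension d")
    zero = np.zeros((1, d), dtype=np.float32)                  # row 0: no token
    unk = rng.normal(0.0, sd, size=(1, d)).astype(np.float32)  # row 1: UNKNOWN
    rows = [zero, unk, pretrained.astype(np.float32)]          # rows 2 .. 2 + n_pre - 1
    n_r = num_words - n_pre                                    # words without pretrained vectors
    if n_r > 0:
        rows.append(rng.normal(0.0, sd, size=(n_r, d)).astype(np.float32))
    return np.concatenate(rows, axis=0)


# Usage: 3 pretrained vectors, vocabulary of 5 words, dimension 4.
pre = np.ones((3, 4), dtype=np.float32)
U = build_embedding_matrix(pre, num_words=5, d=4, seed=0)
print(U.shape)  # (7, 4): 1 zero + 1 UNK + 3 pretrained + 2 random

ids = np.array([0, 1, 2, 6])  # token ids to look up
vecs = U[ids]                 # row gather, one vector per id
print(vecs.shape)  # (4, 4)
```

Indexing U[ids] gathers rows the same way tf.nn.embedding_lookup does on a single embedding tensor, so id 0 always yields the zero vector. One difference to keep in mind: in the TensorFlow version only the Variable pieces (the UNK row, the pretrained rows and the extra random rows) receive gradient updates, while the constant row 0 stays zero; a plain NumPy array has no such distinction.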