negative_sampling.py source code

python

Project: lda2vec-tf — Author: meereeum
    def __call__(self, embed, train_labels):

        with tf.name_scope("negative_sampling"):
            # mask out skip or OOV
            # if switched on, this yields ...
            # UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.

            # mask = tf.greater(train_labels, NegativeSampling.IGNORE_LABEL_MAX)
            # # mask = tf.not_equal(train_labels, NegativeSampling.IGNORE_LABEL)
            # embed = tf.boolean_mask(embed, mask)
            # train_labels = tf.expand_dims(tf.boolean_mask(train_labels, mask), -1)
            train_labels = tf.expand_dims(train_labels, -1)

            # Compute the average NCE loss for the batch.
            # tf.nn.nce_loss automatically draws a new sample of the negative
            # labels each time we evaluate the loss.
            # By default this uses a log-uniform (Zipfian) distribution for
            # sampling, and therefore assumes labels are sorted by descending
            # frequency - which they are!

            # fall back to the default log-uniform sampler when no frequencies
            # are given; otherwise sample from the (distorted) unigram distribution
            sampler = (None if self.freqs is None
                       else tf.nn.fixed_unigram_candidate_sampler(
                               train_labels, num_true=1, num_sampled=self.sample_size,
                               unique=True, range_max=self.vocab_size,
                               #num_reserved_ids=2, # skip or OoV
                               # ^ only if not in unigrams
                               distortion=self.power, unigrams=list(self.freqs)))

            loss = tf.reduce_mean(
                    tf.nn.nce_loss(self.nce_weights, self.nce_biases,
                                   embed, # summed doc and context embedding
                                   train_labels, self.sample_size, self.vocab_size,
                                   sampled_values=sampler), # log-uniform if not specified
                    name="nce_batch_loss")
            # TODO negative sampling versus NCE
            # TODO uniform vs. Zipf with exponent `distortion` param
            #https://www.tensorflow.org/versions/r0.12/api_docs/python/nn.html#log_uniform_candidate_sampler

        return loss
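The two sampling distributions mentioned in the comments and TODOs above (distorted unigram vs. log-uniform/Zipfian) can be illustrated with a small NumPy sketch. This is only an illustration of the math, not part of the project's code; the log-uniform formula follows the TensorFlow candidate-sampler documentation linked above, and the function names here are made up for the example:

```python
import numpy as np

def distorted_unigram_probs(freqs, distortion=0.75):
    """Distribution used by fixed_unigram_candidate_sampler: each raw
    frequency is raised to `distortion` before normalizing.
    distortion=1.0 -> raw unigram counts; distortion=0.0 -> uniform."""
    f = np.asarray(freqs, dtype=np.float64) ** distortion
    return f / f.sum()

def log_uniform_probs(range_max):
    """Zipfian distribution assumed by the default log-uniform sampler:
    P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1).
    It assigns higher probability to lower ids, which is why labels
    must be sorted by descending frequency."""
    k = np.arange(range_max, dtype=np.float64)
    return (np.log(k + 2) - np.log(k + 1)) / np.log(range_max + 1)

# distortion flattens the unigram counts toward uniform
print(distorted_unigram_probs([100, 50, 10, 1], distortion=0.75))
# the log-uniform sampler is strictly decreasing in the label id
print(log_uniform_probs(10))
```

Passing `distortion=self.power` as in the code above therefore interpolates between the raw corpus frequencies (power 1.0) and a uniform negative-sampling distribution (power 0.0).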