encoder_decoder.py source code

python

Project: seqmod  Author: emanjavacas
import torch


def schedule_sampling(self, prev, dec_out):
        """
        Scheduled sampling: replace a random subset of the next-step inputs
        with the model's own predictions. For each example in the batch, an
        independent Bernoulli draw with probability self.exposure_rate decides
        whether the original input in prev is kept (success) or replaced by
        the model's greedy prediction (failure).

        Parameters:
        -----------

        - prev: torch.LongTensor(batch_size)
        - dec_out: torch.Tensor(batch_size x hid_dim)

        Returns:
        --------

        - prev: torch.LongTensor(batch_size), the partially resampled input
        """
        # detach so that no computation is registered for autograd
        prev, dec_out = prev.detach(), dec_out.detach()

        # keep_mask[i] is True with probability self.exposure_rate,
        # meaning example i keeps its original input from prev
        keep_mask = torch.bernoulli(
            torch.zeros_like(prev).float() + self.exposure_rate) == 1

        # return early if no sampling is necessary (every input is kept)
        if keep_mask.all():
            return prev

        # greedy prediction: highest-scoring output index for each example
        with torch.no_grad():
            sampled = self.decoder.project(dec_out).max(1)[1]

        # return early if no input is kept (all inputs are sampled)
        if not keep_mask.any():
            return sampled

        # restore the original inputs wherever keep_mask is set
        sampled[keep_mask] = prev[keep_mask]

        return sampled
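To make the mixing rule concrete outside the model, here is a minimal framework-free sketch that uses plain Python lists instead of tensors. The function name `schedule_sampling_sketch` and the `scores` argument are hypothetical. `scores` stands in for the output of `self.decoder.project(dec_out)`, assumed to be one row of per-token scores per example. Only the rule itself follows the method above: keep each input with probability `exposure_rate`, otherwise take the highest-scoring token.

```python
import random
from typing import List, Optional, Sequence


def schedule_sampling_sketch(
    prev: Sequence[int],
    scores: Sequence[Sequence[float]],
    exposure_rate: float,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical illustration: keep prev[i] with probability exposure_rate,
    otherwise use the argmax of scores[i] (like .max(1)[1] on a tensor)."""
    if rng is None:
        rng = random.Random()
    result = []
    for token, row in zip(prev, scores):
        # Bernoulli draw: success (probability exposure_rate) keeps the input
        if rng.random() < exposure_rate:
            result.append(token)
        else:
            # greedy prediction: index of the highest score in this row
            result.append(max(range(len(row)), key=row.__getitem__))
    return result


if __name__ == "__main__":
    prev = [4, 7, 1]
    scores = [[0.1, 0.9, 0.0], [0.5, 0.2, 0.3], [0.0, 0.0, 1.0]]
    print(schedule_sampling_sketch(prev, scores, exposure_rate=1.0))  # [4, 7, 1]
    print(schedule_sampling_sketch(prev, scores, exposure_rate=0.0))  # [1, 0, 2]
```

With `exposure_rate=1.0`, every original input is kept, which corresponds to the method's first early return. With `0.0`, every input is replaced by the greedy prediction, which corresponds to its second early return. Passing a seeded `random.Random` makes the draws reproducible.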