layers.py source file

python

Project: spotlight  Author: maciejkula
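import numpy as np
import torch
from sklearn.utils import murmurhash3_32


# Method excerpt: `self` is the enclosing embedding layer.
def _get_hashed_indices(self, original_indices):

    def _hash(x, seed):
        # Hash every index in x with the given seed.
        # TODO: integrate with padding index
        result = murmurhash3_32(x, seed=seed)
        # x is np.arange(num_embeddings), so position padding_idx
        # holds the padding index itself; force its hash to 0.
        result[self.padding_idx] = 0

        return result % self.compressed_num_embeddings

    if self._hashes is None:
        # Build the lookup table once and cache it: one row per
        # original index, one column per seed in self._masks.
        indices = np.arange(self.num_embeddings, dtype=np.int32)
        hashes = np.stack([_hash(indices, seed)
                           for seed in self._masks],
                          axis=1).astype(np.int64)
        assert hashes[self.padding_idx].sum() == 0

        self._hashes = torch.from_numpy(hashes)

        if original_indices.is_cuda:
            self._hashes = self._hashes.cuda()

    # Gather the precomputed rows for the requested indices.
    hashed_indices = torch.index_select(self._hashes,
                                        0,
                                        original_indices.squeeze())

    return hashed_indices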
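
The method above precomputes a table of hashed indices (one column per seed) and then gathers rows from it. The following is a minimal standalone sketch of that idea, not the spotlight implementation: it assumes `zlib.crc32` seeded with a start value as a stand-in for `murmurhash3_32`, uses plain lists in place of tensors, and the names `build_hash_table` and `lookup` are hypothetical.

```python
import zlib


def build_hash_table(num_embeddings, compressed_num_embeddings, seeds, padding_idx=0):
    """Return one row per original index, with one hashed column per seed."""
    table = []
    for index in range(num_embeddings):
        if index == padding_idx:
            # The padding index always maps to row 0 of the compressed table.
            table.append([0] * len(seeds))
            continue
        row = []
        for seed in seeds:
            # Assumption: crc32 with the seed as its starting value replaces
            # murmurhash3_32 so the sketch needs only the standard library.
            digest = zlib.crc32(index.to_bytes(4, "little"), seed)
            row.append(digest % compressed_num_embeddings)
        table.append(row)
    return table


def lookup(table, original_indices):
    """Gather precomputed rows, mirroring index_select along dimension 0."""
    return [table[i] for i in original_indices]


table = build_hash_table(num_embeddings=1000,
                         compressed_num_embeddings=50,
                         seeds=[1, 2, 3, 4])
hashed = lookup(table, [0, 7, 7, 999])
```

Each original index ends up with several positions in the smaller table, one per seed. Building the table once and reusing it keeps the hashing cost out of every later lookup, which is the same reason the original caches `self._hashes`.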