# Assumes module-level imports of numpy (as np), torch, and
# sklearn.utils.murmurhash3_32.
def _get_hashed_indices(self, original_indices):

    def _hash(x, seed):
        # TODO: integrate with padding index
        result = murmurhash3_32(x, seed=seed)
        result[self.padding_idx] = 0

        return result % self.compressed_num_embeddings

    if self._hashes is None:
        # Precompute the hashed positions for every possible index once,
        # then reuse the cached table on subsequent calls.
        indices = np.arange(self.num_embeddings, dtype=np.int32)
        hashes = np.stack([_hash(indices, seed)
                           for seed in self._masks],
                          axis=1).astype(np.int64)
        # The padding index must map to position 0 under every hash.
        assert hashes[self.padding_idx].sum() == 0

        self._hashes = torch.from_numpy(hashes)

        if original_indices.is_cuda:
            self._hashes = self._hashes.cuda()

    # Gather the precomputed hashed positions for the requested indices.
    hashed_indices = torch.index_select(self._hashes,
                                        0,
                                        original_indices.squeeze())

    return hashed_indices
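
For context, here is a minimal, self-contained sketch of the same multi-hash lookup idea outside of any class: every original index is hashed with several fixed seeds via sklearn's murmurhash3_32, reduced modulo a smaller table size, and the resulting positions are gathered with torch.index_select. The concrete values below (table sizes, seeds, padding_idx, the positive=True flag) are illustrative assumptions, not taken from the snippet above.

import numpy as np
import torch
from sklearn.utils import murmurhash3_32

num_embeddings = 10_000            # size of the original index space (assumed)
compressed_num_embeddings = 1_000  # size of the shared, compressed table (assumed)
num_hash_functions = 4
padding_idx = 0
SEEDS = [179424941, 179425457, 179425907, 179426369]  # arbitrary fixed seeds

# Precompute, for every original index, its hashed position under each seed.
indices = np.arange(num_embeddings, dtype=np.int32)
hashes = np.stack(
    [murmurhash3_32(indices, seed=seed, positive=True) % compressed_num_embeddings
     for seed in SEEDS[:num_hash_functions]],
    axis=1,
).astype(np.int64)
hashes[padding_idx] = 0  # padding index maps to slot 0 under every hash

hash_table = torch.from_numpy(hashes)

# Look up the hashed positions for a batch of original indices.
original_indices = torch.tensor([3, 42, 7, padding_idx])
hashed_indices = torch.index_select(hash_table, 0, original_indices)
print(hashed_indices.shape)  # torch.Size([4, 4])

Each row of hashed_indices then selects num_hash_functions rows from the compressed embedding table, which are typically summed or averaged to form the final embedding for that index.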