def prune(self, size):
    """Return a dictionary holding only the `size` most frequent entries.

    When `size` is at least `self.size()`, `self` itself is returned (no
    copy is made). Otherwise a new `Dict` is built that copies
    `self.lower`, registers every entry of `self.special` first via
    `addSpecial`, and then adds the labels of the `size` indices with the
    highest frequency via `add`.

    Args:
        size: Number of top-ranked indices (by frequency) to carry over.

    Returns:
        `self` if no pruning is needed, otherwise the newly built `Dict`.
    """
    if size >= self.size():
        return self

    # Look frequencies up by index (rather than iterating the container)
    # so that position k of the tensor always corresponds to index k.
    freq = torch.Tensor(
        [self.frequencies[i] for i in range(len(self.frequencies))])
    # Sort along dim 0 in descending order; only the permutation is needed.
    _, idx = torch.sort(freq, 0, True)

    newDict = Dict()
    newDict.lower = self.lower

    # Add special entries in all cases.
    for i in self.special:
        newDict.addSpecial(self.idxToLabel[i])

    # Convert to plain Python ints before the lookup: iterating a tensor
    # yields 0-dim tensors, which hash by identity and therefore would not
    # match integer keys if `idxToLabel` is a dict.
    for i in idx[:size].tolist():
        newDict.add(self.idxToLabel[i])

    return newDict
# Comment list
# Article table of contents