def prune(self, size):
    """Return a Dict restricted to special entries plus the `size` most frequent ones.

    If `size` is at least `self.size()` there is nothing to prune, and `self`
    itself is returned (not a copy).

    Otherwise a new `Dict` is built. First, the label of every index in
    `self.special` is added via `addSpecial`. Then the labels of the `size`
    indices with the highest `self.frequencies` values are added via `add`,
    in descending frequency order.

    NOTE(review): whether a special label that also ranks in the top `size`
    ends up stored once, and what frequency the new entries get, depends on
    `Dict.add`/`Dict.addSpecial` (not visible here) — confirm.

    Args:
        size: number of most-frequent entries to keep; special entries are
            added in addition to these.

    Returns:
        `self` when no pruning is needed, otherwise the newly built `Dict`.
    """
    if size >= self.size():
        return self

    # Only keep the `size` most frequent entries.
    # NOTE(review): assumes self.frequencies holds an entry for every index
    # 0..len-1. torch.Tensor uses the default float dtype (float32 unless
    # changed), so extremely large counts may not be represented exactly.
    freq = torch.Tensor(
        [self.frequencies[i] for i in range(len(self.frequencies))])
    _, idx = torch.sort(freq, 0, True)  # descending by frequency

    newDict = Dict()
    # Add special entries in all cases.
    for i in self.special:
        newDict.addSpecial(self.idxToLabel[i])

    # Convert the index tensor to plain Python ints. Iterating the tensor
    # directly yields 0-d tensors, which hash by identity and therefore fail
    # as keys when idxToLabel is a dict.
    for i in idx[:size].tolist():
        newDict.add(self.idxToLabel[i])
    return newDict
# Convert `labels` to indices. Use `unkWord` if not found.
# Optionally insert `bosWord` at the beginning and `eosWord` at the end.
评论列表
文章目录