def get_data(self, path, batch_size=20):
    """Tokenize a text file and return its word ids arranged in batches.

    The file at ``path`` is read twice: the first pass registers every
    word with ``self.dictionary``; the second pass maps each word to its
    id via ``self.dictionary.word2idx``. Each line is split on whitespace
    and terminated with an ``'<eos>'`` token.

    Args:
        path: Path of the text file to tokenize.
        batch_size: Number of rows in the returned tensor.

    Returns:
        A ``torch.long`` tensor of shape ``(batch_size, num_batches)``
        where ``num_batches = total_tokens // batch_size``. Trailing
        tokens that do not fill a complete column are dropped. If the
        file holds fewer than ``batch_size`` tokens, the result has
        shape ``(batch_size, 0)``.
    """
    def _words():
        # Stream the file so the raw text is never held in memory at once.
        # NOTE: the file is decoded with the platform default encoding.
        with open(path, 'r') as f:
            for line in f:
                yield from line.split()
                yield '<eos>'

    # Pass 1: add every word to the dictionary before any lookup happens.
    for word in _words():
        self.dictionary.add_word(word)

    # Pass 2: translate the token stream into ids and build the tensor in
    # one step instead of filling an uninitialised tensor element by element.
    word2idx = self.dictionary.word2idx
    ids = torch.tensor([word2idx[word] for word in _words()],
                       dtype=torch.long)

    num_batches = ids.size(0) // batch_size
    ids = ids[:num_batches * batch_size]
    # Give the column count explicitly: view(batch_size, -1) cannot infer
    # the missing dimension when the tensor is empty and would raise.
    return ids.view(batch_size, num_batches)
评论列表
文章目录