from collections import defaultdict
from itertools import dropwhile, pairwise  # pairwise requires Python 3.10+
from os import path

import numpy as np
from nltk.tokenize import TreebankWordTokenizer


def __prepare__(self):
    """Tokenize the corpus, build the word/index mappings, and group
    the training conversation pairs into length buckets.
    """
    # Use context managers so the file handles get closed; the Cornell
    # Movie-Dialogs files are commonly read as ISO-8859-1.
    with open(path.join(self.BASE_PATH, self.CONVS_FILE), 'r',
              encoding='iso-8859-1') as f:
        conversations = f.readlines()
    with open(path.join(self.BASE_PATH, self.LINES_FILE), 'r',
              encoding='iso-8859-1') as f:
        movie_lines = f.readlines()
    tbt = TreebankWordTokenizer().tokenize

    # Map each line id to its token list; the utterance text is the last
    # FILE_SEP-separated field. Collect the vocabulary along the way.
    self.words_set = set()
    self.lines_dict = {}
    for line in movie_lines:
        parts = [p.strip() for p in line.lower().split(self.FILE_SEP)]
        tokens = tbt(parts[-1])
        self.lines_dict[parts[0]] = tokens
        self.words_set |= set(tokens)
    # Special tokens take the first three indices; real words follow.
    self.word2idx = {self.PAD_TOKEN: 0, self.EOS_TOKEN: 1, self.GO_TOKEN: 2}
    for i, word in enumerate(self.words_set):
        self.word2idx[word] = i + 3

    # Inverse mapping: idx2word[i] is the word with index i.
    self.idx2word = [None] * len(self.word2idx)
    for w, i in self.word2idx.items():
        self.idx2word[i] = w
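    # Illustrative sanity check (not in the original): the two mappings
    # should be exact inverses of each other.
    assert all(self.word2idx[w] == i for i, w in enumerate(self.idx2word))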
    # Extract pairs of consecutive lines in a conversation:
    # (s0, s1, s2) -> {(s0, s1), (s1, s2)}.
    utt_pairs = []
    for line in conversations:
        # The last field looks like "['l194', 'l195', 'l196']": strip the
        # brackets, split on ', ', then strip the quotes around each id.
        fields = [p.strip() for p in line.lower().split(self.FILE_SEP)]
        line_ids = [p[1:-1] for p in fields[-1][1:-1].split(', ')]
        utt_pairs += list(pairwise(line_ids))
    # Shuffle, then hold out the first VAL_COUNT pairs for validation.
    utt_pairs = np.random.permutation(utt_pairs)
    train_utt_pairs = utt_pairs[self.VAL_COUNT:]
    self.val_pairs = utt_pairs[:self.VAL_COUNT]
    def find_bucket(enc_size, dec_size, buckets):
        # First bucket large enough on both sides, or None if none fits.
        return next(dropwhile(lambda b: enc_size > b[0] or dec_size > b[1],
                              buckets), None)
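    # For example, with hypothetical self.bucket_sizes = [(5, 10), (10, 15),
    # (20, 25)], find_bucket(6, 11, ...) skips (5, 10) because the encoder
    # side does not fit, and returns (10, 15).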
    # Group training pairs by the smallest bucket that fits both sides;
    # pairs too long for every bucket go under the sentinel key (-1, -1).
    # (Assumes bucket_pairs is not already initialized elsewhere.)
    self.bucket_pairs = defaultdict(list)
    for pair in train_utt_pairs:
        bckt = find_bucket(len(self.lines_dict[pair[0]]),
                           len(self.lines_dict[pair[1]]),
                           self.bucket_sizes)
        if bckt is None:
            self.bucket_pairs[(-1, -1)].append(pair)
        else:
            self.bucket_pairs[bckt].append(pair)
    # Visit buckets in order of how many pairs they hold, largest first.
    self.bucket_ordering = sorted(self.bucket_pairs,
                                  key=lambda b: len(self.bucket_pairs[b]),
                                  reverse=True)
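
# A minimal usage sketch, not from the original source: assuming the method
# above belongs to a dataset class (the name CornellData and its constructor
# are hypothetical), preparation would look like:
#
#     ds = CornellData()
#     ds.__prepare__()
#     len(ds.word2idx)      # vocabulary size, including the special tokens
#     ds.bucket_ordering    # buckets sorted by pair count, largest first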