cornell.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:seq2seq-lasagne 作者: erfannoury 项目源码 文件源码
def __prepare__(self):
    """
    Load the corpus files and build the vocabulary, utterance pairs and buckets.

    Reads ``self.LINES_FILE`` and ``self.CONVS_FILE`` from ``self.BASE_PATH``
    and sets the following attributes:

    * ``self.words_set``: set of every token found in the lines file.
    * ``self.lines_dict``: maps the first field of each line record to the
      token list of its last field.
    * ``self.word2idx`` / ``self.idx2word``: token <-> index tables, with
      indices 0, 1 and 2 reserved for ``PAD_TOKEN``, ``EOS_TOKEN`` and
      ``GO_TOKEN``.
    * ``self.val_pairs``: the first ``self.VAL_COUNT`` pairs after shuffling.
    * ``self.bucket_ordering``: bucket keys sorted by how many training
      pairs they hold, largest first.

    The remaining (training) pairs are appended to ``self.bucket_pairs``
    under the first entry of ``self.bucket_sizes`` that fits both sides, or
    under ``(-1, -1)`` when none fits.

    NOTE(review): ``self.bucket_pairs`` is not created here; it is assumed to
    already exist and to yield a list for every key used below (for example
    a ``defaultdict(list)``) -- confirm against the constructor.
    """
    # Context managers guarantee both files are closed, even on error.
    with open(path.join(self.BASE_PATH, self.CONVS_FILE), 'r') as convs_file:
        conversations = convs_file.readlines()
    with open(path.join(self.BASE_PATH, self.LINES_FILE), 'r') as lines_file:
        movie_lines = lines_file.readlines()

    tokenize = TreebankWordTokenizer().tokenize
    self.words_set = set()
    self.lines_dict = {}
    for line in movie_lines:
        # A real list (not ``map``) so it can be indexed on Python 3 as well.
        fields = [field.strip() for field in line.lower().split(self.FILE_SEP)]
        tokens = tokenize(fields[-1])
        self.lines_dict[fields[0]] = tokens
        self.words_set.update(tokens)

    # Reserved tokens take indices 0..2; corpus words follow in set order.
    # Skipping words that equal a reserved token keeps the indices dense, so
    # ``idx2word`` and ``word2idx`` always stay consistent with each other.
    specials = (self.PAD_TOKEN, self.EOS_TOKEN, self.GO_TOKEN)
    self.idx2word = list(specials)
    self.idx2word.extend(w for w in self.words_set if w not in specials)
    self.word2idx = {word: idx for idx, word in enumerate(self.idx2word)}

    # extract pairs of lines in a conversation (s0, s1, s2) -> {(s0, s1), (s1, s2)}
    utt_pairs = []
    for line in conversations:
        # The last field looks like "['l1', 'l2', ...]": strip the brackets,
        # split on ', ' and drop the quote character around every id.
        id_list = line.lower().split(self.FILE_SEP)[-1].strip()[1:-1]
        line_ids = [quoted[1:-1] for quoted in id_list.split(', ')]
        utt_pairs.extend(pairwise(line_ids))
    utt_pairs = np.random.permutation(utt_pairs)
    train_utt_pairs = utt_pairs[self.VAL_COUNT:]
    self.val_pairs = utt_pairs[:self.VAL_COUNT]

    def find_bucket(enc_size, dec_size, buckets):
        """Return the first bucket that fits both lengths, or None."""
        return next(
            (b for b in buckets if enc_size <= b[0] and dec_size <= b[1]),
            None,
        )

    for pair in train_utt_pairs:
        bucket = find_bucket(
            len(self.lines_dict[pair[0]]),
            len(self.lines_dict[pair[1]]),
            self.bucket_sizes,
        )
        # Pairs too long for every bucket are collected under (-1, -1).
        key = (-1, -1) if bucket is None else bucket
        self.bucket_pairs[key].append(pair)

    # Largest buckets first; ties keep the dict's iteration order because
    # ``sorted`` is stable even with ``reverse=True``.
    self.bucket_ordering = sorted(
        self.bucket_pairs,
        key=lambda b: len(self.bucket_pairs[b]),
        reverse=True,
    )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号