memnet.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:fathom 作者: rdadolf 项目源码 文件源码
def load_data(self):
    """Load one bAbI task, vectorize it, and store the data splits on ``self``.

    Reads the task identified by the module-level ``data_dir`` and
    ``task_id`` through ``load_task``, builds a word -> index vocabulary
    over the combined train and test examples, and derives the size
    attributes used for vectorization.

    Attributes set on ``self``:
        memory_size: 50, capped at the longest story length.
        max_story_size, mean_story_size: story-length statistics.
        sentence_size: longest sentence or query length, whichever is larger.
        query_size: longest query length.
        vocab_size: vocabulary size plus one for the nil word.
        S, Q, A: the vectorized training examples before splitting.
        trainS/valS, trainQ/valQ, trainA/valA: 90/10 train/validation split.
        testS, testQ, testA: the vectorized test examples.
        n_train, n_val, n_test: number of rows in each split.
    """
    # single babi task
    # TODO: refactor all this running elsewhere
    # task data
    train, test = load_task(data_dir, task_id)
    # Concatenate once; every statistic below is computed over both splits.
    data = train + test

    # Collect every distinct token from stories, queries and answers.
    tokens = set()
    for story, query, answer in data:
        tokens.update(chain.from_iterable(story), query, answer)
    vocab = sorted(tokens)
    # Indices start at 1; index 0 is left for the nil word (see vocab_size).
    word_idx = {word: idx + 1 for idx, word in enumerate(vocab)}

    self.memory_size = 50

    # Build a real list: np.mean() cannot consume a lazy map object on
    # Python 3, so the lengths must be materialized before averaging.
    story_lengths = [len(story) for story, _, _ in data]
    self.max_story_size = max(story_lengths)
    self.mean_story_size = int(np.mean(story_lengths))
    self.sentence_size = max(
        len(sentence) for story, _, _ in data for sentence in story
    )
    self.query_size = max(len(query) for _, query, _ in data)
    self.memory_size = min(self.memory_size, self.max_story_size)
    self.vocab_size = len(word_idx) + 1 # +1 for nil word
    self.sentence_size = max(self.query_size, self.sentence_size) # for the position

    print("Longest sentence length", self.sentence_size)
    print("Longest story length", self.max_story_size)
    print("Average story length", self.mean_story_size)

    # train/validation/test sets
    self.S, self.Q, self.A = vectorize_data(train, word_idx, self.sentence_size, self.memory_size)
    # NOTE(review): sklearn's `cross_validation` module was superseded by
    # `model_selection`; confirm which one the installed version provides.
    self.trainS, self.valS, self.trainQ, self.valQ, self.trainA, self.valA = cross_validation.train_test_split(self.S, self.Q, self.A, test_size=.1) # TODO: randomstate
    self.testS, self.testQ, self.testA = vectorize_data(test, word_idx, self.sentence_size, self.memory_size)

    print(self.testS[0])

    print("Training set shape", self.trainS.shape)

    # params
    self.n_train = self.trainS.shape[0]
    self.n_test = self.testS.shape[0]
    self.n_val = self.valS.shape[0]

    print("Training Size", self.n_train)
    print("Validation Size", self.n_val)
    print("Testing Size", self.n_test)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号