tagger_data.py source code

python

Project: deep_srl  Author: luheng
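import random  # needed by random.shuffle below


def get_training_data(self, include_last_batch=False):
    """Get shuffled training samples. Called at the beginning of each epoch."""
    # TODO: Speed up: Use variable size batches (different max length).
    # list(...) is required on Python 3, where a range object cannot be shuffled in place.
    train_ids = list(range(len(self.train_sents)))
    random.shuffle(train_ids)

    if not include_last_batch:
        # Drop the trailing partial batch so that every batch is full.
        num_batches = len(train_ids) // self.batch_size
        train_ids = train_ids[:num_batches * self.batch_size]

    # Count the samples actually kept; using len(self.train_sents) here would
    # yield an empty final batch whenever a partial batch was dropped.
    num_samples = len(train_ids)
    tensors = [self.train_tensors[t] for t in train_ids]
    batched_tensors = [tensors[i:i + self.batch_size]
                       for i in range(0, num_samples, self.batch_size)]
    # Transpose each batch so the i-th field of every sample is grouped together.
    results = [list(zip(*t)) for t in batched_tensors]

    print("Extracted {} samples and {} batches.".format(num_samples, len(batched_tensors)))
    return results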
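
The method above depends on the attributes of its enclosing class, so it cannot be run in isolation. Below is a minimal standalone sketch of the same shuffle, truncate, batch and transpose steps. The name make_batches, the seed parameter and the toy data are assumptions made for illustration and are not part of deep_srl.

```python
import random


def make_batches(samples, batch_size, include_last_batch=False, seed=None):
    """Shuffle samples and group them into transposed batches.

    Hypothetical standalone helper; not part of deep_srl.
    """
    rng = random.Random(seed)  # local generator so the example is repeatable
    ids = list(range(len(samples)))
    rng.shuffle(ids)

    if not include_last_batch:
        # Drop the trailing partial batch so that every batch is full.
        ids = ids[:(len(ids) // batch_size) * batch_size]

    shuffled = [samples[i] for i in ids]
    batches = [shuffled[i:i + batch_size]
               for i in range(0, len(shuffled), batch_size)]
    # zip(*batch) turns [(x1, y1), (x2, y2)] into [(x1, x2), (y1, y2)].
    return [list(zip(*batch)) for batch in batches]


# Toy data (assumed shape): five (word_ids, label_ids) pairs.
toy = [([i, i + 1], [i % 2, 0]) for i in range(5)]

full_only = make_batches(toy, batch_size=2, seed=0)
with_tail = make_batches(toy, batch_size=2, include_last_batch=True, seed=0)
print(len(full_only), len(with_tail))  # prints: 2 3
```

With five samples and a batch size of 2, dropping the partial batch leaves two full batches, while include_last_batch=True keeps a third batch that holds the single remaining sample.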