import sys

from tensorflow.python.platform import gfile

import data_utils  # project-local module providing EOS_ID


def read_data(config, source_path, target_path, max_size=None):
  """Read parallel token-id data and group sentence pairs into buckets."""
  data_set = [[] for _ in config.buckets]
  with gfile.GFile(source_path, mode="r") as source_file:
    with gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      counter = 0
      while source and target and (not max_size or counter < max_size):
        counter += 1
        if counter % 100000 == 0:
          print("reading data line %d" % counter)
          sys.stdout.flush()
        # Each line holds space-separated token ids; append EOS to the target.
        source_ids = [int(x) for x in source.strip().split()]
        target_ids = [int(x) for x in target.strip().split()]
        target_ids.append(data_utils.EOS_ID)
        # Place the pair into the first (smallest) bucket that fits both sides.
        for bucket_id, (source_size, target_size) in enumerate(config.buckets):
          if len(source_ids) < source_size and len(target_ids) < target_size:
            data_set[bucket_id].append([source_ids, target_ids])
            break
        source, target = source_file.readline(), target_file.readline()
  return data_set
Source: grl_train.py (Python)
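A minimal usage sketch follows. The bucket sizes, the SimpleNamespace config object, and the input file paths are illustrative assumptions, not values taken from grl_train.py; the only requirements read_data places on its arguments are a config with a buckets list and paths to files of space-separated token ids.

from types import SimpleNamespace

# Hypothetical bucket configuration: (max source length, max target length).
config = SimpleNamespace(buckets=[(5, 10), (10, 15), (20, 25), (40, 50)])

# Hypothetical paths to tokenized-id files, one sentence per line.
train_set = read_data(config, "data/train.ids.source", "data/train.ids.target",
                      max_size=100000)

# Each entry of train_set holds the sentence pairs that fit that bucket.
for bucket_id, (src_len, tgt_len) in enumerate(config.buckets):
  print("bucket %d (<%d, <%d): %d pairs"
        % (bucket_id, src_len, tgt_len, len(train_set[bucket_id])))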