grl_train.py source code

Project: Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow  Author: liuyuemaicha

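```python
import sys

# TF 1.x import path (assumed); tf.io.gfile.GFile is the TF 2 equivalent.
from tensorflow.python.platform import gfile

import data_utils  # project module that defines EOS_ID


def read_data(config, source_path, target_path, max_size=None):
    """Read parallel token-id files and group the pairs into buckets.

    Each line holds space-separated token ids. Returns one list per bucket
    in config.buckets, each entry a [source_ids, target_ids] pair. Pairs
    that fit no bucket are dropped. max_size caps the lines read (None or 0 = all).
    """
    data_set = [[] for _ in config.buckets]
    with gfile.GFile(source_path, mode="r") as source_file:
        with gfile.GFile(target_path, mode="r") as target_file:
            source, target = source_file.readline(), target_file.readline()
            counter = 0
            # readline() returns "" at EOF, so stop when either file runs out.
            while source and target and (not max_size or counter < max_size):
                counter += 1
                if counter % 100000 == 0:
                    print("reading data line %d" % counter)
                    sys.stdout.flush()
                source_ids = [int(x) for x in source.strip().split()]
                target_ids = [int(x) for x in target.strip().split()]
                target_ids.append(data_utils.EOS_ID)
                # Put the pair in the first bucket (in config.buckets order) it fits into.
                for bucket_id, (source_size, target_size) in enumerate(config.buckets):
                    if len(source_ids) < source_size and len(target_ids) < target_size:
                        data_set[bucket_id].append([source_ids, target_ids])
                        break
                source, target = source_file.readline(), target_file.readline()
    return data_set
```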
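The bucketing rule in `read_data` can be exercised without TensorFlow or the project's files. The sketch below is a hypothetical re-implementation (`bucket_pairs` is not part of the project) that accepts any iterables of whitespace-separated token-id lines, such as open files or `io.StringIO` objects. It assumes `EOS_ID = 2`, the value used by the TensorFlow translate tutorial's `data_utils`; the project's own `data_utils.EOS_ID` may differ.

```python
from typing import Iterable, List, Optional, Sequence, Tuple

EOS_ID = 2  # assumption: TF translate tutorial value; the project's may differ

Pair = List[List[int]]  # [source_ids, target_ids]


def bucket_pairs(
    source_lines: Iterable[str],
    target_lines: Iterable[str],
    buckets: Sequence[Tuple[int, int]],
    max_size: Optional[int] = None,
) -> List[List[Pair]]:
    """Hypothetical, TF-free sketch of read_data's bucketing logic."""
    data_set: List[List[Pair]] = [[] for _ in buckets]
    # zip() stops at the shorter input, like the readline()-based loop.
    for counter, (source, target) in enumerate(zip(source_lines, target_lines), start=1):
        if max_size and counter > max_size:  # None or 0 means no limit
            break
        source_ids = [int(x) for x in source.split()]
        target_ids = [int(x) for x in target.split()]
        target_ids.append(EOS_ID)
        # Strict "<" on both sides, exactly as in read_data.
        for bucket_id, (source_size, target_size) in enumerate(buckets):
            if len(source_ids) < source_size and len(target_ids) < target_size:
                data_set[bucket_id].append([source_ids, target_ids])
                break
        # A pair that fits no bucket falls through and is silently dropped.
    return data_set


if __name__ == "__main__":
    buckets = [(5, 10), (10, 15)]
    sources = ["1 2 3", "4 5 6 7 8 9", "1 2 3 4 5 6 7 8 9 10 11"]
    targets = ["7 8", "9", "3"]
    print(bucket_pairs(sources, targets, buckets))
```

With these inputs the first pair lands in bucket 0, the second (6 source tokens) in bucket 1, and the third (11 source tokens) fits no bucket and is dropped. Because the comparisons are strict, a `(5, 10)` bucket accepts sources of at most 4 tokens and targets of at most 9 tokens, counting the appended EOS.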