def __init__(self, source, target,
             source_dicts, target_dict,
             batch_size=128,
             maxlen=100,
             n_words_source=-1,
             n_words_target=-1,
             shuffle_each_epoch=False,
             sort_by_length=True,
             indomain_source='', indomain_target='',
             interpolation_rate=0.1,
             maxibatch_size=20):
    """Open the corpora, load the dictionaries and initialise iterator state.

    Args:
        source: Path of the source-side corpus, opened with ``fopen``.
        target: Path of the target-side corpus, opened with ``fopen``.
        source_dicts: Iterable of values, each passed to ``load_dict``; the
            results are kept, in order, in ``self.source_dicts``.
        target_dict: Value passed to ``load_dict`` for the target side.
        batch_size: Stored as ``self.batch_size``; also a factor of
            ``self.k``.
        maxlen: Stored as ``self.maxlen`` (not used in this method).
        n_words_source: When greater than 0, every entry whose value is
            ``>= n_words_source`` is removed from each source dictionary.
        n_words_target: When greater than 0, every entry whose value is
            ``>= n_words_target`` is removed from the target dictionary.
        shuffle_each_epoch: When true, ``shuffle.main`` is called on both
            file pairs and the ``.shuf``-suffixed paths are opened instead
            of the originals. Stored as ``self.shuffle``.
        sort_by_length: Stored as ``self.sort_by_length`` (not used in this
            method).
        indomain_source: Path of the in-domain source corpus.
        indomain_target: Path of the in-domain target corpus.
        interpolation_rate: Fraction of ``self.k`` assigned to
            ``self.indomain_k`` (rounded up).
        maxibatch_size: Multiplier applied to ``batch_size`` to obtain
            ``self.k``.

    NOTE(review): the in-domain paths default to ``''`` and are passed to
    ``fopen`` (and ``shuffle.main``) unconditionally, so callers presumably
    must always supply them -- confirm against callers.
    """
    if shuffle_each_epoch:
        shuffle.main([source, target])
        shuffle.main([indomain_source, indomain_target])
        # Presumably shuffle.main writes "<path>.shuf" next to each input;
        # those are the files read from here on -- verify in shuffle.py.
        self.source = fopen(source+'.shuf', 'r')
        self.target = fopen(target+'.shuf', 'r')
        self.indomain_source = fopen(indomain_source+'.shuf', 'r')
        self.indomain_target = fopen(indomain_target+'.shuf', 'r')
    else:
        self.source = fopen(source, 'r')
        self.target = fopen(target, 'r')
        self.indomain_source = fopen(indomain_source, 'r')
        self.indomain_target = fopen(indomain_target, 'r')

    self.source_dicts = [load_dict(source_dict)
                         for source_dict in source_dicts]
    self.target_dict = load_dict(target_dict)

    self.batch_size = batch_size
    self.maxlen = maxlen

    self.n_words_source = n_words_source
    self.n_words_target = n_words_target

    # Trim the vocabularies in place. The items are snapshotted with list()
    # first: deleting from a dict while iterating its live items() view
    # raises RuntimeError on Python 3.
    if self.n_words_source > 0:
        for d in self.source_dicts:
            for key, idx in list(d.items()):
                if idx >= self.n_words_source:
                    del d[key]

    if self.n_words_target > 0:
        for key, idx in list(self.target_dict.items()):
            if idx >= self.n_words_target:
                del self.target_dict[key]

    self.shuffle = shuffle_each_epoch
    self.sort_by_length = sort_by_length

    self.source_buffer = []
    self.target_buffer = []
    # k is batch_size * maxibatch_size; it is split below into an in-domain
    # share (rounded up) and an out-of-domain remainder, so the two parts
    # always sum to k.
    self.k = batch_size * maxibatch_size

    self.end_of_data = False

    self.interpolation_rate = interpolation_rate
    self.indomain_k = int(math.ceil(self.interpolation_rate * self.k))
    self.outdomain_k = self.k - self.indomain_k
domain_interpolation_data_iterator.py 文件源码
python
阅读 40
收藏 0
点赞 0
评论 0
评论列表
文章目录