def setup_data(self, path):
    """Read a labelled TSV file and yield its examples one at a time.

    Every usable row of the file has exactly three tab-separated columns:
    a label followed by two pieces of text. Rows with any other number of
    columns are reported with ``print`` and skipped.

    Which rows are yielded depends on ``self.datatype_strict``:

    * ``'test'`` -- every usable row, in file order;
    * ``'train'`` -- the training indices of the selected fold;
    * anything else -- the held-out indices of the selected fold.

    The fold is chosen with ``self.opt['bagging_fold_index']`` (default 0)
    out of ``self.opt['bagging_folds_number']`` folds. If the index is past
    the last fold, the last fold is used.

    The seed for the fold split is drawn from ``self.random_state`` (a state
    tuple of the ``random`` module). ``self.random_state`` is advanced
    afterwards, and the global ``random`` state is left as it was found.

    Args:
        path: Path of the TSV file to read.

    Yields:
        ``((text, labels), episode_done)`` where ``text`` is
        ``self.question``, a newline, and the two text columns joined by a
        newline; ``labels`` is a one-element list; ``episode_done`` is
        always ``True``.
    """
    print('loading: ' + path)

    questions = []
    y = []

    # Open the data file with labels. newline='' is what the csv module
    # asks for so that it can do its own line-ending handling.
    with open(path, newline='') as labels_file:
        tsv_reader = csv.reader(labels_file, delimiter='\t')
        for row in tsv_reader:
            if len(row) != 3:
                print('Warn: expected 3 columns in a tsv row, got ' + str(row))
                continue
            label, first_text, second_text = row
            # NOTE(review): these label strings look like non-ASCII text that
            # was lost to an encoding error -- confirm the intended values
            # before changing them.
            y.append(['??' if label == '1' else '???'])
            questions.append(first_text + '\n' + second_text)

    # Each example is a complete episode on its own.
    episode_done = True

    indexes = range(len(questions))
    if self.datatype_strict != 'test':
        use_train_part = self.datatype_strict == 'train'
        wanted_fold = self.opt.get('bagging_fold_index', 0)

        # Temporarily swap in the private random state so that the fold seed
        # does not depend on (or disturb) other users of the random module.
        outer_random_state = random.getstate()
        random.setstate(self.random_state)
        try:
            kf_seed = random.randrange(500000)
            # KFold is presumably scikit-learn's -- verify against the
            # file's imports.
            kf = KFold(self.opt.get('bagging_folds_number'), shuffle=True,
                       random_state=kf_seed)
            # Walk the splits until the requested fold is reached; the fold
            # counter has to advance for any index other than 0 to work.
            for fold_number, (train_index, test_index) in enumerate(kf.split(questions)):
                indexes = train_index if use_train_part else test_index
                if fold_number >= wanted_fold:
                    break
            self.random_state = random.getstate()
        finally:
            # Restore the global state even if the split fails.
            random.setstate(outer_random_state)

    # Iterate over the selected examples.
    for i in indexes:
        # yield tuple with information and episode_done flag
        yield (self.question + "\n" + questions[i], y[i]), episode_done
# [web-page residue, not code] Comment list
# [web-page residue, not code] Article table of contents