def filter_batch(self, batch_index, *args):
    """Keep, for every label, only the lowest-cost samples of a batch.

    The per-label quota comes from ``self.cost_threshold(self.iteration)``.
    For each label in ``range(self.model.output_size)``, the quota-th smallest
    cost among that label's samples becomes a threshold. Every sample of that
    label whose cost is ``<=`` the threshold is kept. Because of the ``<=``
    comparison, ties at the threshold can keep more samples than the quota.

    Args:
        batch_index: Index array used to slice each array in
            ``self.all_data``; its entries are what gets returned.
        *args: Accepted for interface compatibility; unused.

    Returns:
        list: The entries of ``batch_index`` that were kept. They are grouped
        by label in ascending label order, and kept in batch order within a
        label. The list is empty when the quota is not positive.

    Side effects:
        When ``Config['temp_job'] == 'log_data'``, calls
        ``self.add_index(index, cost)`` once for each kept sample, in the same
        order the samples are appended to the result.
    """
    selected_number = self.cost_threshold(self.iteration)
    result = []
    if selected_number <= 0:
        # heapq.nsmallest returns [] for a non-positive n, so there is no
        # threshold element to take; a non-positive quota keeps nothing.
        return result

    selected_batch_data = self.prepare_data(
        *[data[batch_index] for data in self.all_data])
    # The last output of prepare_data is treated as the label array.
    targets = selected_batch_data[-1]
    cost_list = self.model.f_cost_list_without_decay(*selected_batch_data)
    # Read the job mode once instead of once per kept sample.
    log_costs = Config['temp_job'] == 'log_data'

    for label in range(self.model.output_size):
        # Assumes targets and cost_list are equal-length 1-D numpy arrays
        # (the boolean masking below relies on it) -- TODO confirm.
        label_mask = targets == label
        label_costs = cost_list[label_mask]
        if label_costs.size == 0:
            continue
        threshold = heapq.nsmallest(selected_number, label_costs)[-1]
        # One vectorised pass instead of rescanning the whole batch per
        # label; nonzero() yields positions in ascending (batch) order.
        keep_positions = (label_mask & (cost_list <= threshold)).nonzero()[0]
        for j in keep_positions:
            result.append(batch_index[j])
            if log_costs:
                self.add_index(batch_index[j], cost_list[j])
    return result
评论列表
文章目录