batch_updater.py source file

python

Project: RL4Data  Author: fyabc
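import heapq

# `Config` is defined elsewhere in the project.
def filter_batch(self, batch_index, *args):
    """Keep, for each label, the samples in the batch with the lowest cost."""
    # Number of samples to keep per label at the current iteration.
    selected_number = self.cost_threshold(self.iteration)

    # Gather the batch from every data array, then preprocess it.
    selected_batch_data = [data[batch_index] for data in self.all_data]
    selected_batch_data = self.prepare_data(*selected_batch_data)

    # The targets (labels) are the last element of the prepared batch.
    targets = selected_batch_data[-1]

    # Per-sample cost, then the costs grouped by label via a boolean mask.
    cost_list = self.model.f_cost_list_without_decay(*selected_batch_data)
    label_cost_lists = [cost_list[targets == label] for label in range(self.model.output_size)]

    result = []

    for i, label_cost_list in enumerate(label_cost_lists):
        if label_cost_list.size != 0:
            # The selected_number-th smallest cost of this label is the cutoff.
            threshold = heapq.nsmallest(selected_number, label_cost_list)[-1]
            for j in range(len(targets)):
                # Ties at the cutoff are all kept, so a label may keep more
                # than selected_number samples.
                if targets[j] == i and cost_list[j] <= threshold:
                    result.append(batch_index[j])
                    if Config['temp_job'] == 'log_data':
                        self.add_index(batch_index[j], cost_list[j])

    return result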
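
The per-label selection step can be tried in isolation. Below is a minimal sketch, not the project's code: `select_lowest_cost_per_label` and its parameters are hypothetical names, the inputs are assumed to be plain sequences or NumPy arrays, and the model, `prepare_data`, and logging parts of the method are left out. It also skips a label when `k <= 0`, which is an added assumption.

```python
import heapq

import numpy as np


def select_lowest_cost_per_label(batch_index, targets, cost_list, num_labels, k):
    """Return the batch indices of the k lowest-cost samples of each label.

    Hypothetical standalone helper; all names here are illustrative.
    """
    batch_index = np.asarray(batch_index)
    targets = np.asarray(targets)
    cost_list = np.asarray(cost_list, dtype=float)

    selected = []
    for label in range(num_labels):
        mask = targets == label
        label_costs = cost_list[mask]
        # Assumption: skip labels with no samples, and skip everything if k <= 0.
        if label_costs.size == 0 or k <= 0:
            continue
        # k-th smallest cost of this label (the largest one if it has fewer than k samples).
        threshold = heapq.nsmallest(k, label_costs)[-1]
        # Keep every sample of this label at or below the cutoff; ties are all kept.
        keep = mask & (cost_list <= threshold)
        selected.extend(batch_index[keep].tolist())
    return selected


if __name__ == "__main__":
    print(select_lowest_cost_per_label(
        batch_index=[10, 11, 12, 13, 14, 15],
        targets=[0, 1, 0, 1, 0, 1],
        cost_list=[0.5, 0.9, 0.1, 0.2, 0.3, 0.7],
        num_labels=3,
        k=2,
    ))
```

The boolean mask replaces the inner `for j` loop, while results stay ordered by label and then by position in the batch.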