callbacks.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:torchsample 作者: ncullen93 项目源码 文件源码
def on_train_end(self, logs=None):
        """Rewrite the CSV log at ``self.file``, replacing this run's row.

        The file is parsed with ``csv.DictReader`` using the delimiter
        ``self.sep`` and the columns ``'model'`` followed by ``self.keys``
        (minus ``'has_validation_data'``). Every row whose ``'model'`` cell
        equals ``self.row_dict['model']`` is replaced by ``self.row_dict``;
        all other rows are copied through unchanged. Because the field names
        are passed explicitly, the file's header line is read as an ordinary
        row and is therefore copied back as-is.

        The new contents are written to a temporary file, which is then moved
        over ``self.file``. If reading or writing fails, the temporary file is
        removed and the exception propagates.

        Args:
            logs: Accepted for the callback interface; not used here.

        Raises:
            OSError: If ``self.file`` cannot be opened or the final move fails.
            ValueError: From ``csv.DictWriter`` when a row to be written has
                keys outside the computed field names (e.g. ``self.row_dict``
                carries extra keys, or a parsed line has more cells than
                there are columns).
        """
        import os  # only needed to clean up the temp file on failure

        REJECT_KEYS = {'has_validation_data'}
        row_dict = self.row_dict
        # One column list shared by reader and writer so they cannot drift apart.
        fieldnames = ['model'] + [k for k in self.keys if k not in REJECT_KEYS]

        class CustomDialect(csv.excel):
            delimiter = self.sep

        # NOTE(review): a former placeholder said the header should be
        # re-written with on_train_end's metrics; that was never implemented
        # and ``logs`` is still unused.

        # Open the source before creating the temp file so a missing or
        # unreadable log does not leave an orphaned temp file behind.
        # newline='' is what the csv module docs prescribe for file objects.
        with open(self.file, 'r', newline='') as csv_file:
            temp_file = NamedTemporaryFile(delete=False, mode='w', newline='')
            try:
                with temp_file:
                    reader = csv.DictReader(csv_file, fieldnames=fieldnames,
                                            dialect=CustomDialect)
                    writer = csv.DictWriter(temp_file, fieldnames=fieldnames,
                                            dialect=CustomDialect)
                    for row in reader:
                        if row['model'] == row_dict['model']:
                            writer.writerow(row_dict)
                        else:
                            writer.writerow(row)
            except BaseException:
                # Do not leave a half-written temp file on disk; re-raise.
                os.remove(temp_file.name)
                raise
        # Move only after both the source and the temp file are closed.
        shutil.move(temp_file.name, self.file)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号