def on_train_end(self, logs=None):
    """Replace this model's row in the CSV log with ``self.row_dict``.

    The log at ``self.file`` (cells separated by ``self.sep``) is streamed
    into a temporary file in the same directory, which is then moved over
    the original:

    * the first row is treated as the header and re-emitted from the
      current column list (``'model'`` followed by ``self.keys``), so the
      header always matches the columns written below it;
    * the row whose ``'model'`` cell equals ``self.row_dict['model']`` is
      replaced by ``self.row_dict``;
    * every other row is copied, mapped positionally onto the current
      columns (cells missing at the end are written empty).

    If no row matches, ``self.row_dict`` is not written. Keys in
    ``REJECT_KEYS`` are excluded both from the columns and from the
    replacement row.

    Args:
        logs: Unused by this method.

    Raises:
        Any error from reading/writing the files propagates after the
        temporary file has been removed; the original log is left untouched
        in that case.
    """
    import os  # only needed here, for the temp-file directory and cleanup

    REJECT_KEYS = {'has_validation_data'}
    fieldnames = ['model'] + [k for k in self.keys if k not in REJECT_KEYS]
    # DictWriter raises on keys that are not columns, so drop the rejected ones.
    row_dict = {k: v for k, v in self.row_dict.items() if k not in REJECT_KEYS}

    class CustomDialect(csv.excel):
        delimiter = self.sep

    # Create the temp file next to the log so the final move is a rename on
    # the same filesystem rather than a cross-device copy.
    log_dir = os.path.dirname(os.path.abspath(self.file))
    temp_file = NamedTemporaryFile(mode='w', newline='', delete=False, dir=log_dir)
    try:
        # newline='' is required by the csv module for correct line handling.
        with open(self.file, 'r', newline='') as csv_file, temp_file:
            reader = csv.DictReader(csv_file, fieldnames=fieldnames,
                                    dialect=CustomDialect)
            writer = csv.DictWriter(temp_file, fieldnames=fieldnames,
                                    dialect=CustomDialect)
            # Consume the old header and re-write it with the current columns.
            if next(reader, None) is not None:
                writer.writeheader()
            for row in reader:
                writer.writerow(row_dict if row['model'] == row_dict['model'] else row)
        # Both files are closed here; keep the log's permission bits (the temp
        # file is created 0600) and swap the new content into place.
        shutil.copymode(self.file, temp_file.name)
        shutil.move(temp_file.name, self.file)
    except BaseException:
        # Don't leave an orphaned delete=False temp file behind.
        temp_file.close()
        try:
            os.remove(temp_file.name)
        except OSError:
            pass
        raise
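# Web-page navigation residue from the page this code was copied from
# ("评论列表" = comment list, "文章目录" = table of contents); not part of the code.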