def on_epoch_end(self, epoch, logs=None):
    """Checkpoint weights and render sample predictions after an epoch.

    Saves the model weights to ``weights<epoch>.h5`` in ``self.output_dir``
    and calls ``self.show_edit_distance(256)``. It then draws the next batch
    from ``self.text_img_gen`` and decodes its first few inputs with
    ``decode_batch(self.test_func, ...)``. Finally it writes a grid of the
    input images, each labelled with its ground-truth and predicted string,
    to ``e<epoch>.png`` in the same directory.

    Args:
        epoch: Index of the epoch that just finished; used in output file
            names.
        logs: Unused. Accepted for the callback hook signature. Defaults to
            ``None`` rather than a shared mutable ``{}``.
    """
    # Label font; this family presumably provides the Thaana glyphs used by
    # the target strings -- confirm it is installed on the plotting host.
    font = {'family': 'Thaana Unicode Akeh',
            'color': 'darkred',
            'weight': 'normal',
            'size': 12,
            }
    self.model.save_weights(os.path.join(self.output_dir, 'weights%02d.h5' % (epoch)))
    self.show_edit_distance(256)
    # NOTE(review): this overrides any value configured elsewhere (e.g. in
    # __init__) on every epoch; kept so the displayed count is unchanged.
    self.num_display_words = 6
    word_batch = next(self.text_img_gen)[0]
    inputs = word_batch['the_input']
    # Never ask for more samples than the batch actually contains, and decode
    # exactly the samples that are plotted below.
    num_words = min(self.num_display_words, len(inputs))
    res = decode_batch(self.test_func, inputs[0:num_words])
    # Inputs whose first dimension is small are laid out two per row.
    cols = 2 if inputs[0].shape[0] < 256 else 1
    # Ceiling division so an odd sample count still fits in the grid.
    rows = -(-num_words // cols)
    for i in range(num_words):
        plt.subplot(rows, cols, i + 1)
        # Drop the single channel axis, whose position depends on the
        # backend's image data format.
        if K.image_data_format() == 'channels_first':
            the_input = inputs[i, 0, :, :]
        else:
            the_input = inputs[i, :, :, 0]
        # Transposed for display; presumably the generator stores images
        # width-first -- confirm against text_img_gen.
        plt.imshow(the_input.T, cmap='Greys_r')
        # Strings are reversed before drawing, presumably because matplotlib
        # does not apply right-to-left layout to Thaana text.
        plt.xlabel('Actual = \'%s\'\nPrediction = \'%s\'' % (word_batch['source_str'][i][::-1], res[i][::-1]), fontdict=font)
    fig = plt.gcf()
    fig.set_size_inches(10, 13)
    plt.savefig(os.path.join(self.output_dir, 'e%02d.png' % (epoch)))
    plt.close()
评论列表
文章目录