def train(self, run_name, start_epoch, stop_epoch, img_w):
    """Build the CTC training model and fit it on generated word images.

    Reads ``self.img_h``, ``self.pool_size``, ``self.words_per_epoch``,
    ``self.val_words``, ``self.inner``, ``self.input_data`` and
    ``self.test_func``; sets ``self.img_gen`` and ``self.model``.

    Args:
        run_name: Sub-directory name under ``OUTPUT_DIR`` used to locate the
            checkpoint when resuming; also passed to ``VizCallback``.
        start_epoch: Epoch to resume from. When greater than 0, the weights
            file ``weights%02d.h5`` for ``start_epoch - 1`` is loaded first.
        stop_epoch: Passed to Keras as ``epochs``.
        img_w: Image width passed to ``TextImageGenerator``.
    """
    # Number of words per epoch that are not reserved for validation.
    train_words = self.words_per_epoch - self.val_words

    self.img_gen = TextImageGenerator(
        monogram_file=os.path.join(DATA_DIR, 'wordlist_mono_dhivehi.txt'),
        bigram_file=os.path.join(DATA_DIR, 'wordlist_bi_dhivehi.txt'),
        minibatch_size=32,
        img_w=img_w,
        img_h=self.img_h,
        downsample_factor=(self.pool_size ** 2),
        val_split=train_words,
    )

    adam = keras.optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Extra inputs consumed by the 'ctc' Lambda layer alongside the predictions.
    labels = Input(name='the_labels',
                   shape=[self.img_gen.absolute_max_string_len],
                   dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')

    y_pred = Activation('softmax', name='softmax')(self.inner)
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
        [y_pred, labels, input_length, label_length])

    self.model = Model(inputs=[self.input_data, labels, input_length, label_length],
                       outputs=loss_out)
    # The model's only output is the 'ctc' layer, so the compiled loss simply
    # passes that output through (presumably the CTC loss computed by
    # ctc_lambda_func -- confirm against its definition).
    self.model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=adam)

    if start_epoch > 0:
        # Resume from the checkpoint written for the previous epoch.
        weight_file = os.path.join(OUTPUT_DIR, run_name,
                                   'weights%02d.h5' % (start_epoch - 1))
        self.model.load_weights(weight_file)

    viz_cb = VizCallback(run_name, self.test_func, self.img_gen.next_val(), self.model)

    # NOTE(review): both step counts below are word counts, not batch counts.
    # A previously commented-out variant, labelled as being for test runs,
    # divided both by the minibatch size -- confirm the full-length form
    # here is intended.
    self.model.fit_generator(
        generator=self.img_gen.next_train(),
        steps_per_epoch=train_words,
        epochs=stop_epoch,
        validation_data=self.img_gen.next_val(),
        validation_steps=self.val_words,
        callbacks=[viz_cb, self.img_gen],
        initial_epoch=start_epoch,
        verbose=1,
    )
# (web page residue: "comment list" / "table of contents")