thaanaocr.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:thaanaOCR 作者: Sofwath 项目源码 文件源码
def train(self, run_name, start_epoch, stop_epoch, img_w):
        """Build the CTC training model and fit it on generated word images.

        Creates a ``TextImageGenerator`` over the Dhivehi mono/bigram word
        lists and puts a softmax over ``self.inner``. It wraps that output,
        together with the label/length inputs, in a ``ctc_lambda_func``
        Lambda layer whose output is used directly as the loss. If
        ``start_epoch > 0``, it reloads the weights saved for epoch
        ``start_epoch - 1`` and then trains with ``fit_generator``.

        Args:
            run_name: Sub-directory of ``OUTPUT_DIR`` that holds this run's
                weight files (``weights%02d.h5``).
            start_epoch: Epoch to start from (Keras ``initial_epoch``). When
                greater than 0, weights from epoch ``start_epoch - 1`` are
                loaded first.
            stop_epoch: Epoch at which training stops (Keras ``epochs``).
            img_w: Width of the generated text images.

        Note:
            Assigns ``self.img_gen`` and ``self.model`` as a side effect.
        """
        words_per_epoch = 16000
        val_split = 0.2
        val_words = int(words_per_epoch * val_split)
        # Word count on the training side of the split. The generator takes
        # it as its train/validation boundary (its ``val_split`` argument).
        # Previously this read self.words_per_epoch / self.val_words, which
        # this method never sets, so the locals above were silently ignored.
        train_words = words_per_epoch - val_words
        minibatch_size = 32

        self.img_gen = TextImageGenerator(
            monogram_file=os.path.join(DATA_DIR, 'wordlist_mono_dhivehi.txt'),
            bigram_file=os.path.join(DATA_DIR, 'wordlist_bi_dhivehi.txt'),
            minibatch_size=minibatch_size,
            img_w=img_w,
            img_h=self.img_h,
            downsample_factor=(self.pool_size ** 2),
            val_split=train_words,
        )

        adam = keras.optimizers.Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

        output_dir = os.path.join(OUTPUT_DIR, run_name)

        # Extra model inputs consumed only by the loss layer.
        labels = Input(name='the_labels', shape=[self.img_gen.absolute_max_string_len], dtype='float32')
        input_length = Input(name='input_length', shape=[1], dtype='int64')
        label_length = Input(name='label_length', shape=[1], dtype='int64')

        y_pred = Activation('softmax', name='softmax')(self.inner)

        loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
            [y_pred, labels, input_length, label_length])

        self.model = Model(inputs=[self.input_data, labels, input_length, label_length], outputs=loss_out)

        # The 'ctc' output already is the loss value, so the Keras loss
        # function just passes it through and ignores y_true.
        self.model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=adam)

        if start_epoch > 0:
            # Resume from the checkpoint written at the end of the previous epoch.
            weight_file = os.path.join(output_dir, 'weights%02d.h5' % (start_epoch - 1))
            self.model.load_weights(weight_file)

        viz_cb = VizCallback(run_name, self.test_func, self.img_gen.next_val(), self.model)

        # NOTE(review): Keras counts steps_per_epoch / validation_steps in
        # batches, but word counts are passed here. Each epoch therefore
        # draws train_words batches of minibatch_size words. This longer
        # schedule was the author's choice and is kept. For quick test runs,
        # divide both values by minibatch_size.
        self.model.fit_generator(generator=self.img_gen.next_train(), steps_per_epoch=train_words,
                                 epochs=stop_epoch, validation_data=self.img_gen.next_val(),
                                 validation_steps=val_words,
                                 callbacks=[viz_cb, self.img_gen], initial_epoch=start_epoch, verbose=1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号