train.py source code

python

Project: DeepLearning-OCR  Author: xingjian-f
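import os
import time
from datetime import datetime

from keras.callbacks import ModelCheckpoint

# get_sample_weight and plot_loss_figure are helpers defined elsewhere in the project.


def train(model, batch_size, nb_epoch, save_dir, train_data, val_data, char_set):
    X_train, y_train = train_data[0], train_data[1]
    sample_weight = get_sample_weight(y_train, char_set)
    print('X_train shape: {}'.format(X_train.shape))
    print('{} train samples'.format(X_train.shape[0]))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    start_time = time.time()
    # Keras fills in {epoch} and {val_loss} when it writes each checkpoint.
    # os.path.join keeps the file inside save_dir even without a trailing slash.
    save_path = os.path.join(save_dir, 'weights.{epoch:02d}-{val_loss:.2f}.hdf5')
    check_pointer = ModelCheckpoint(save_path, save_best_only=True)
    history = model.fit(
        X_train, y_train,
        batch_size=batch_size,
        nb_epoch=nb_epoch,  # Keras 1.x argument name; Keras 2 renamed it to `epochs`
        validation_data=val_data,
        validation_split=0.1,  # ignored by Keras when validation_data is given
        callbacks=[check_pointer],
        sample_weight=sample_weight,
    )

    # Name the loss plot after the current time of day (HH:MM:SS).
    figure_name = str(datetime.now()).split('.')[0].split()[1] + '.jpg'
    plot_loss_figure(history, os.path.join(save_dir, figure_name))
    print('Training time(h): {}'.format((time.time() - start_time) / 3600))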
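
The checkpoint filename in train() is an ordinary Python format string, so the names it produces can be previewed without running Keras. The snippet below is a minimal sketch of that idea: `checkpoint_path` is a hypothetical helper that is not part of the project, and the directory name and metric values are made up for illustration.

```python
import os

# Same template string as in train(); the placeholders are plain str.format fields.
TEMPLATE = 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'


def checkpoint_path(save_dir, epoch, val_loss):
    """Hypothetical helper: preview the path a checkpoint would be written to."""
    filename = TEMPLATE.format(epoch=epoch, val_loss=val_loss)
    # os.path.join inserts the separator, so save_dir needs no trailing slash.
    return os.path.join(save_dir, filename)


# Illustrative values only: epoch 3 with a validation loss of 0.4567.
example = checkpoint_path('model_out', 3, 0.4567)
print(os.path.basename(example))  # weights.03-0.46.hdf5
```

Because `{epoch:02d}` zero-pads to two digits, checkpoint files sort in training order for runs of up to 99 epochs, and `{val_loss:.2f}` rounds the loss to two decimals, so two epochs with close losses can differ only in the epoch field.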