def main(weights_file_path=None):
    """Build the CNN from ``build_shallow`` and train it on the image data.

    The character set, the maximum number of characters per label and the
    label set are all derived from the training directory.  Validation and
    training data are loaded with ``load_data`` and handed to ``train``
    together with ``save_dir``, the output directory
    ``save_model/<YYYY-MM-DD>/``.

    Args:
        weights_file_path: optional path to a saved weights file (an earlier
            run used 'save_model/2016-10-27/weights.11-1.58.hdf5').  When
            given, the weights are loaded into the freshly built model before
            training starts.  Defaults to None, which trains from scratch
            exactly as this script always did.
    """
    # Input geometry (an earlier experiment used 48x48).
    img_width, img_height = 200, 60
    img_channels = 1
    batch_size = 32  # an earlier experiment used 1024
    nb_epoch = 1000

    # One directory per calendar day, not per run: runs started on the same
    # day share it.  date().isoformat() yields the same 'YYYY-MM-DD' string
    # the old str(datetime.now()).split('.')[0].split()[0] parsing produced.
    save_dir = 'save_model/' + datetime.now().date().isoformat() + '/'
    train_data_dir = 'train_data/ip_train/'  # alternative: 'train_data/single_1000000/'
    val_data_dir = 'train_data/ip_val/'

    # Vocabulary and label statistics all come from the training directory.
    char_set, char2idx = get_char_set(train_data_dir)
    nb_classes = len(char_set)
    max_nb_char = get_maxnb_char(train_data_dir)
    label_set = get_label_set(train_data_dir)

    # A single parenthesised expression prints the same line under the
    # Python 2 print statement and the Python 3 print function.
    print('nb_classes: %s' % nb_classes)
    print('max_nb_char: %s' % max_nb_char)
    print('size_label_set: %s' % len(label_set))

    model = build_shallow(img_channels, img_width, img_height, max_nb_char, nb_classes)
    if weights_file_path is not None:
        # Resume from a previously saved checkpoint instead of random init.
        model.load_weights(weights_file_path)

    val_data = load_data(val_data_dir, max_nb_char, img_width, img_height,
                         img_channels, char_set, char2idx)
    train_data = load_data(train_data_dir, max_nb_char, img_width, img_height,
                           img_channels, char_set, char2idx)
    train(model, batch_size, nb_epoch, save_dir, train_data, val_data, char_set)

    # Evaluation instead of training (previously toggled by uncommenting
    # code here): load trained weights, then for each of 'train_data/ip_train/',
    # 'train_data/ip_val/' and 'test_data/' call
    #     data = load_data(<dir>, max_nb_char, img_width, img_height,
    #                      img_channels, char_set, char2idx)
    #     test(model, data, char_set, label_set, post_correction)
    # with post_correction=False as configured before.
评论列表
文章目录