def train(model, batch_size, nb_epoch, save_dir, train_data, val_data, char_set):
    """Fit ``model`` on ``train_data`` and write checkpoints and a loss plot.

    Args:
        model: Object whose ``fit`` method accepts the keyword arguments
            used below (``batch_size``, ``nb_epoch``, ``validation_data``,
            ``validation_split``, ``callbacks``, ``sample_weight``).
        batch_size: Forwarded unchanged to ``model.fit``.
        nb_epoch: Forwarded unchanged to ``model.fit``.
        save_dir: Directory that receives the checkpoint files and the loss
            figure. It is created (including parents) when it is missing.
        train_data: Indexable whose items 0 and 1 are the training inputs
            and the training targets; the inputs must expose ``.shape``.
        val_data: Forwarded to ``model.fit`` as ``validation_data``.
        char_set: Passed with the training targets to ``get_sample_weight``.

    Returns:
        Whatever ``model.fit`` returns (the value that is also handed to
        ``plot_loss_figure``).
    """
    X_train, y_train = train_data[0], train_data[1]
    sample_weight = get_sample_weight(y_train, char_set)

    # Single pre-formatted strings print identically whether ``print`` is
    # the Python 2 statement or the Python 3 function.
    print('X_train shape: %s' % (X_train.shape,))
    print('%s train samples' % (X_train.shape[0],))

    # makedirs also handles nested paths, which a bare mkdir rejects.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    start_time = time.time()

    # os.path.join keeps the files inside save_dir even when the caller
    # passes the directory without a trailing separator.
    save_path = os.path.join(save_dir,
                             'weights.{epoch:02d}-{val_loss:.2f}.hdf5')
    check_pointer = ModelCheckpoint(save_path, save_best_only=True)

    # NOTE(review): both validation arguments are forwarded, as before.
    # Keras documents that validation_data overrides validation_split when
    # it is provided -- confirm for the Keras version in use.
    history = model.fit(
        X_train,
        y_train,
        batch_size=batch_size,
        nb_epoch=nb_epoch,
        validation_data=val_data,
        validation_split=0.1,
        callbacks=[check_pointer],
        sample_weight=sample_weight,
    )

    # The figure is named after the wall-clock time (HH:MM:SS) at which
    # training finished.
    figure_name = datetime.now().strftime('%H:%M:%S') + '.jpg'
    plot_loss_figure(history, os.path.join(save_dir, figure_name))

    print('Training time(h): %s' % ((time.time() - start_time) / 3600,))
    return history
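# Residue from the web page this snippet was copied from
# (navigation labels: "Comment list", "Article contents").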