def train():
    """Build, compile and fit the ResNet model on the training data.

    The network comes from ``resnet()``. It is compiled with an SGD optimizer
    (learning rate ``_sgd_lr``, decay ``_sgd_decay``, momentum 0.9, Nesterov),
    the ``_objective`` loss and mean absolute error (``'mae'``) as a reported
    metric. ``load_train_data()`` supplies the inputs and targets. These are
    streamed through an ``ImageDataGenerator`` whose only active augmentation
    is random width/height shifting.

    Returns:
        The fitted Keras model returned by ``resnet()``.

    NOTE(review): this uses Keras 1.x argument names (``lr``,
    ``samples_per_epoch``, ``nb_epoch``) and ``fit_generator``. Later Keras
    releases renamed or removed these, so confirm which Keras version this
    targets before upgrading.
    """
    # Keras is imported inside the function rather than at module level.
    from keras.optimizers import SGD
    from keras.preprocessing.image import ImageDataGenerator

    logging.info('... building model')
    optimizer = SGD(lr=_sgd_lr, decay=_sgd_decay, momentum=0.9, nesterov=True)
    net = resnet()
    net.compile(loss=_objective, optimizer=optimizer, metrics=['mae'])

    logging.info('... loading data')
    inputs, targets = load_train_data()

    logging.info('... training')
    # Data augmentation: shift by up to 1/8 of the image width/height (a float
    # below 1 is a fraction of the image size in Keras). Rotation, shear and
    # zoom are explicitly set to 0.
    augment = dict(
        width_shift_range=1. / 8.,
        height_shift_range=1. / 8.,
        rotation_range=0.,
        shear_range=0.,
        zoom_range=0.,
    )
    generator = ImageDataGenerator(**augment)
    batches = generator.flow(inputs, targets, batch_size=_batch_size)

    net.fit_generator(
        batches,
        # Number of samples drawn per epoch equals the number of training rows.
        samples_per_epoch=inputs.shape[0],
        nb_epoch=_nb_epoch,
        verbose=1,
    )
    return net
trainer.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录