def start_keras(logger, job_backend):
    """Run a Keras training job described by ``job_backend``.

    The steps are order-sensitive and run in exactly this sequence:

    1. Default ``KERAS_BACKEND`` to ``'tensorflow'`` when the environment does
       not already choose one.
    2. Import ``keras_model_utils``, which (per the original note) is what
       loads Keras and thereby fixes the active backend.
    3. Switch the process working directory to the job's git work tree.
    4. Record the ``origin`` remote URL and the job id as system info, grouped
       under one ``'Git Version'`` batch commit.
    5. Force a channels-last image layout on the Keras backend, through
       whichever layout setter(s) the installed Keras exposes.
    6. Build a ``Trainer`` and a ``KerasCallback``, report progress 0 of the
       configured number of epochs, and hand off to
       ``keras_model_utils.job_start``.
    7. Mark the job as done.

    :param logger: logger that receives the start-up debug/info messages.
    :param job_backend: job backend object; it must expose ``git`` (with
        ``work_tree``, ``batch_commit``, ``get_remote_url`` and ``job_id``),
        ``set_system_info``, ``logger``, ``job['config']['epochs']``,
        ``progress`` and ``done``.
    :returns: None.

    Side effects: may add ``KERAS_BACKEND`` to ``os.environ`` and changes the
    process-wide current working directory. ``job_backend.done()`` is reached
    only if ``job_start`` returns normally; an exception from it propagates
    without the job being marked done.
    """
    # Pick the backend before anything below imports Keras; an explicit
    # user choice in the environment is left untouched.
    os.environ.setdefault('KERAS_BACKEND', 'tensorflow')

    # Importing this module pulls in Keras, which settles which backend (and
    # whether a GPU) is used, so it must come after the line above.
    from . import keras_model_utils

    work_tree = job_backend.git.work_tree
    os.chdir(work_tree)
    logger.debug("Start simple model")

    # Training runs on the job commit's sources directly; record where those
    # sources came from in a single batched commit.
    with job_backend.git.batch_commit('Git Version'):
        remote_url = job_backend.git.get_remote_url('origin')
        job_backend.set_system_info('git_remote_url', remote_url)
        job_backend.set_system_info('git_version', job_backend.git.job_id)

    # Every shape in this project uses TensorFlow ordering:
    # (height, width, channels). Call each layout setter this Keras version
    # provides; both may be present, in which case both are applied.
    import keras.backend as K
    if hasattr(K, 'set_image_dim_ordering'):
        K.set_image_dim_ordering('tf')
    if hasattr(K, 'set_image_data_format'):
        K.set_image_data_format('channels_last')

    from .KerasCallback import KerasCallback
    model_trainer = Trainer(job_backend)
    progress_callback = KerasCallback(job_backend, job_backend.logger)

    total_epochs = job_backend.job['config']['epochs']
    job_backend.progress(0, total_epochs)

    logger.info("Start training")
    keras_model_utils.job_start(job_backend, model_trainer, progress_callback)

    job_backend.done()
评论列表
文章目录