def _train():
    """Build the model, fit it on the training generator, and print the metrics.

    Relies on module-level names defined elsewhere in the file (``Input``,
    ``build_model``, ``BATCH``, ``INCORRECT_BATCH``, ``data_generator``, ...).
    """
    # The model input is a 4096-dim feature vector per sample.
    image_features = Input(shape=(4096,))
    model = build_model(image_features)
    # summary() prints the table itself and returns None; wrapping it in
    # print would emit a stray "None" line.
    model.summary()

    num_train = get_num_train_images()

    # Callbacks
    tensorboard = TensorBoard(log_dir="logs/{}".format(time()))
    epoch_cb = EpochCheckpoint(folder="./snapshots/")
    valid_cb = ValidCallBack()

    # int(): math.ceil returns a float on Python 2, while Keras documents
    # steps_per_epoch as an integer.
    steps_per_epoch = int(math.ceil(num_train / float(BATCH)))
    print("Steps per epoch i.e number of iterations: {}".format(
        steps_per_epoch))

    # NOTE(review): the step count is derived from BATCH, but the generator
    # is built with INCORRECT_BATCH -- confirm these are meant to differ.
    train_datagen = data_generator(batch_size=INCORRECT_BATCH,
                                   image_class_ranges=TRAINING_CLASS_RANGES)
    history = model.fit_generator(
        train_datagen,
        steps_per_epoch=steps_per_epoch,
        epochs=250,
        # epoch_cb used to be constructed but never registered. Presumably it
        # writes the snapshots/epoch_N.hdf5 files that _test() loads -- confirm.
        callbacks=[tensorboard, valid_cb, epoch_cb],
    )
    # list(): gives identical output on Py2 and a readable list on Py3.
    print(list(history.history.keys()))


def _test():
    """Load the saved snapshot and return the model.

    Returns:
        The Keras model restored from ``snapshots/epoch_49.hdf5``.
    """
    from keras.models import load_model
    # hinge_rank_loss is a custom loss, so load_model needs custom_objects
    # to resolve it by name.
    model = load_model("snapshots/epoch_49.hdf5",
                       custom_objects={"hinge_rank_loss": hinge_rank_loss})
    # TODO(review): the loaded model is not used yet; evaluation is missing.
    return model


def main():
    """Entry point: train or load the model depending on ``sys.argv[1]``.

    Usage: ``python <script> TRAIN`` or ``python <script> TEST``.

    Raises:
        SystemExit: if the mode argument is missing or is not TRAIN/TEST.
    """
    if len(sys.argv) < 2:
        sys.exit("usage: {} TRAIN|TEST".format(sys.argv[0]))
    run_mode = sys.argv[1]

    if run_mode == "TRAIN":
        _train()
    elif run_mode == "TEST":
        _test()
    else:
        # Previously an unknown mode silently did nothing; fail loudly instead.
        sys.exit("unknown mode {!r}; expected TRAIN or TEST".format(run_mode))

    # Reset backend state once work is done (K is presumably keras.backend).
    K.clear_session()
评论列表
文章目录