model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:devise-keras 作者: priyamtejaswin 项目源码 文件源码
def main():
    """Command-line entry point: train the model or load a saved snapshot.

    The run mode is read from ``sys.argv[1]``:

    * ``"TRAIN"`` -- build the model on a ``(4096,)`` image-feature input,
      train it for 250 epochs with ``fit_generator`` (TensorBoard logging,
      per-epoch checkpointing into ``./snapshots/`` and a validation
      callback), then print the keys of the training history.
    * ``"TEST"`` -- load ``snapshots/epoch_49.hdf5``, registering the custom
      ``hinge_rank_loss`` so the saved model can be deserialized.

    The Keras backend session is cleared before returning.

    Raises:
        SystemExit: if no run mode is given or the mode is not recognised.
    """
    usage = "usage: python model.py TRAIN|TEST"
    # Fail with a readable message instead of a bare IndexError.
    if len(sys.argv) < 2:
        sys.exit(usage)
    RUN_TIME = sys.argv[1]
    if RUN_TIME not in ("TRAIN", "TEST"):
        # Previously an unknown mode silently did nothing.
        sys.exit("unknown run mode {!r}; {}".format(RUN_TIME, usage))

    if RUN_TIME == "TRAIN":
        image_features = Input(shape=(4096,))
        model = build_model(image_features)
        # summary() prints the table itself and returns None; wrapping it in
        # print (as before) emitted a stray "None" line.
        model.summary()

        # Number of training images.
        _num_train = get_num_train_images()

        # Callbacks.
        tensorboard = TensorBoard(log_dir="logs/{}".format(time()))
        # NOTE(review): epoch_cb used to be created but never handed to
        # fit_generator, so nothing was checkpointed into ./snapshots/ -- the
        # folder the TEST branch loads from. Confirm EpochCheckpoint writes
        # files named "epoch_<N>.hdf5".
        epoch_cb = EpochCheckpoint(folder="./snapshots/")
        valid_cb = ValidCallBack()

        # math.ceil returns a float on Python 2; Keras expects an integer
        # number of generator steps per epoch.
        steps_per_epoch = int(math.ceil(_num_train / float(BATCH)))
        print("Steps per epoch i.e number of iterations:  {}".format(steps_per_epoch))

        # NOTE(review): steps_per_epoch is derived from BATCH while the
        # generator is built with batch_size=INCORRECT_BATCH -- confirm that
        # this pairing is intended.
        train_datagen = data_generator(batch_size=INCORRECT_BATCH, image_class_ranges=TRAINING_CLASS_RANGES)
        history = model.fit_generator(
            train_datagen,
            steps_per_epoch=steps_per_epoch,
            epochs=250,
            callbacks=[tensorboard, valid_cb, epoch_cb],
        )
        print(list(history.history.keys()))

    else:  # RUN_TIME == "TEST"
        from keras.models import load_model
        # hinge_rank_loss is a custom loss, so it must be supplied for the
        # saved model to deserialize.
        # NOTE(review): the loaded model is not used before the session is
        # cleared below; evaluation code appears to be missing.
        model = load_model("snapshots/epoch_49.hdf5", custom_objects={"hinge_rank_loss": hinge_rank_loss})

    K.clear_session()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号