cnn.py source code

python

Project: dogsVScats  Author: prajwalkr
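
from keras.optimizers import SGD
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

# Keras 1.x API (fit_generator with samples_per_epoch / nb_val_samples).
# use_multiscale, use_multicrop, DataGen, ms_traingen, ms_valgen,
# samples_per_epoch, nb_val_samples and dumper are defined elsewhere in cnn.py (not shown).
def runner(model, epochs):
    initial_LR = 0.001
    # Plain generators, or multi-scale / multi-crop ones if either flag is set.
    if not use_multiscale and not use_multicrop:
        training_gen, val_gen = DataGen()
    else:
        training_gen, val_gen = ms_traingen(), ms_valgen()

    # SGD with Nesterov momentum; binary cross-entropy for dog vs. cat.
    model.compile(optimizer=SGD(initial_LR, momentum=0.9, nesterov=True),
                  loss='binary_crossentropy')

    # bestval.h5: weights with the lowest val_loss so far; current.h5: latest epoch.
    val_checkpoint = ModelCheckpoint('bestval.h5', monitor='val_loss',
                                     verbose=1, save_best_only=True)
    cur_checkpoint = ModelCheckpoint('current.h5')
    # def lrForEpoch(i): return initial_LR
    # Halve the learning rate after 2 epochs without val_loss improvement.
    lrScheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2,
                                    cooldown=1, verbose=1)
    print('Model compiled.')

    try:
        model.fit_generator(training_gen, samples_per_epoch, epochs,
                            verbose=1, validation_data=val_gen,
                            nb_val_samples=nb_val_samples,
                            callbacks=[val_checkpoint, cur_checkpoint, lrScheduler])
    except Exception as e:
        print(e)
    finally:
        # Always dump the model. Because of the return below, this also
        # swallows KeyboardInterrupt, so Ctrl+C stops training but still saves.
        fname = dumper(model, 'cnn')
        print('Model saved to disk at {}'.format(fname))
        return model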
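
Both callbacks in `runner` watch `val_loss`. The pure-Python sketch below (no Keras needed) replays a made-up sequence of validation losses to show when `bestval.h5` would be overwritten and how `ReduceLROnPlateau(factor=0.5, patience=2, cooldown=1)` would lower the learning rate from `initial_LR = 0.001`. It is loosely modelled on the Keras 2 `ReduceLROnPlateau` update rule; `simulate_callbacks` is a hypothetical helper, the loss values are invented for illustration, and `min_delta=1e-4` is an assumed threshold. `min_lr` handling and version-specific counter details are omitted, so treat it as an approximation rather than the library's exact behavior.

```python
from typing import List, Tuple


def simulate_callbacks(val_losses: List[float], lr: float = 0.001,
                       factor: float = 0.5, patience: int = 2,
                       cooldown: int = 1,
                       min_delta: float = 1e-4) -> Tuple[List[float], List[int]]:
    """Replay per-epoch val_loss values (hypothetical helper, simplified sketch).

    Returns (lrs, saved_epochs): the learning rate in effect after each epoch,
    and the 0-based epochs on which a save_best_only checkpoint would fire.
    """
    best_ckpt = float('inf')      # best val_loss seen by the checkpoint
    best_plateau = float('inf')   # best val_loss seen by the LR scheduler
    wait = 0                      # consecutive epochs without improvement
    cooldown_counter = 0          # epochs left in cooldown after a reduction
    lrs, saved_epochs = [], []

    for epoch, loss in enumerate(val_losses):
        # save_best_only in 'min' mode: save whenever val_loss strictly improves.
        if loss < best_ckpt:
            best_ckpt = loss
            saved_epochs.append(epoch)

        # Plateau rule: count non-improving epochs outside cooldown and
        # multiply the LR by `factor` once `patience` is reached.
        if cooldown_counter > 0:
            cooldown_counter -= 1
            wait = 0
        if loss < best_plateau - min_delta:
            best_plateau = loss
            wait = 0
        elif cooldown_counter == 0:
            wait += 1
            if wait >= patience:
                lr *= factor
                cooldown_counter = cooldown
                wait = 0
        lrs.append(lr)

    return lrs, saved_epochs


lrs, saved = simulate_callbacks([0.70, 0.60, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66])
print(lrs)    # [0.001, 0.001, 0.001, 0.0005, 0.0005, 0.00025, 0.00025, 0.000125]
print(saved)  # [0, 1]
```

With these invented losses the best-model checkpoint fires only on epochs 0 and 1, while the learning rate halves at epochs 3, 5 and 7 (0-based), because the loss never improves again after epoch 1.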