main_bnum.py source code

python

Project: VGG    Author: jackfan00
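from keras.layers import Input, Flatten, Dense
from keras.models import Model
from keras.applications.vgg16 import VGG16

import cfgconst  # project config module (provides side, classes, bnum)


def VGGregionModel(inputshape):
    input_tensor = Input(shape=inputshape)  # e.g. (448, 448, 3)
    # ImageNet-pretrained VGG16 convolutional base, without the classifier head
    vgg_model = VGG16(input_tensor=input_tensor, weights='imagenet', include_top=False)

    # add region detection layers
    x = vgg_model.output
    x = Flatten()(x)
    #x = Dense(256, activation='relu')(x)
    #x = Dense(2048, activation='relu')(x)
    #x = Dropout(0.5)(x)
    # side*side grid cells, bnum boxes per cell, (classes + 5) values per box
    x = Dense((cfgconst.side**2) * (cfgconst.classes + 5) * cfgconst.bnum)(x)

    # Keras 2 keyword names (Keras 1 used input=/output=)
    model = Model(inputs=vgg_model.input, outputs=x)

    print('returned model:')
    # Freeze layers 0..10 (18 - 8 = 10): the input layer plus VGG16 blocks 1-3.
    # Block 4, block 5 and the new Dense head stay trainable.
    for index, l in enumerate(model.layers):
        if index <= (18 - 8):
            l.trainable = False
        #print(l.name + ' ' + str(l.input_shape) + ' -> ' + str(l.output_shape) + ', trainable:' + str(l.trainable))

    return model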

#
# pretrained
#model = VGGregionModel((448, 448, 3))
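The final `Dense` layer emits one flat vector of length side² × (classes + 5) × bnum, so any loss or decoding code has to reshape it back into a per-cell, per-box grid. The sketch below shows that arithmetic and reshape with NumPy. It is an illustration under assumptions, not code from the project: the values `SIDE = 7`, `CLASSES = 20`, `BNUM = 2` are hypothetical stand-ins for the real `cfgconst` settings (not shown in the source), the helper names `output_size` and `split_predictions` are made up for this example, and the cell-then-box ordering is an assumed convention, since the `Dense` layer itself imposes no layout. In a YOLO-style encoding the 5 extra values per box would presumably be box coordinates plus a confidence score, but the source does not say so.

```python
import numpy as np

# Hypothetical config values; the real ones come from the project's cfgconst module.
SIDE, CLASSES, BNUM = 7, 20, 2


def output_size(side, classes, bnum):
    """Length of the flat vector produced by the final Dense layer."""
    return (side ** 2) * (classes + 5) * bnum


def split_predictions(flat, side, classes, bnum):
    """Reshape (batch, output_size) into (batch, side, side, bnum, classes + 5).

    Assumes the flat vector is ordered cell by cell, then box by box; the loss
    function must use the same ordering for the reshape to be meaningful.
    """
    batch = flat.shape[0]
    return flat.reshape(batch, side, side, bnum, classes + 5)


n = output_size(SIDE, CLASSES, BNUM)
print(n)  # 2450 = 7 * 7 * 25 * 2

# VGG16 without its top downsamples by 32: a 448x448 input gives a 14x14x512 map
flat_features = 14 * 14 * 512
print(flat_features)  # 100352
print(flat_features * n + n)  # Dense weights + biases: 245864850

preds = np.random.rand(4, n).astype(np.float32)
grid = split_predictions(preds, SIDE, CLASSES, BNUM)
print(grid.shape)  # (4, 7, 7, 2, 25)
```

The parameter count shows why the commented-out `Dense(256)` / `Dense(2048)` bottleneck lines are worth considering: connecting the flattened 100352-dimensional feature map directly to the output layer dominates the model size.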