train_bts.py source code

python

Project: Msc_Multi_label_ZeroShot    Author: thomasSve
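# Keras 2.x imports (assumed; the original file's import block is not shown)
from keras import backend as K
from keras.applications.inception_v3 import InceptionV3
from keras.layers import Dense, Dropout, Lambda
from keras.models import Model


def define_network(vector_size, loss):
    # Load InceptionV3 with its ImageNet weights and classifier head
    base_model = InceptionV3(weights='imagenet', include_top=True)

    # Freeze every layer of the pretrained model so only the new head is trained
    for layer in base_model.layers:
        layer.trainable = False

    # Fully connected head on top of the 2048-d global-average-pooled features
    # (layers[-2] is the 'avg_pool' layer just before the 1000-way softmax)
    x = Dense(4096, activation='relu', name='fc1')(base_model.layers[-2].output)
    x = Dense(8096, activation='relu', name='fc2')(x)
    x = Dropout(0.5)(x)
    x = Dense(2048, activation='relu', name='fc3')(x)
    # Project to the target vector size
    predictions = Dense(vector_size, activation='relu')(x)
    # L2-normalize each output vector so it lies on the unit sphere
    l2 = Lambda(lambda t: K.l2_normalize(t, axis=1))(predictions)
    model = Model(inputs=base_model.inputs, outputs=l2)

    optimizer = 'adam'
    if loss == 'euclidean':
        # euclidean_distance is a custom loss defined elsewhere in the project (not shown here)
        model.compile(optimizer=optimizer, loss=euclidean_distance)
    else:
        # Otherwise pass the given loss (a Keras loss name or callable) straight through
        model.compile(optimizer=optimizer, loss=loss)

    return model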
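Because the model's output is L2-normalized, a Euclidean loss on it is closely tied to cosine similarity: for unit vectors a and b, ||a - b||^2 = 2 - 2 * cos(a, b). The pure-Python sketch below checks this identity numerically. It is only an illustration under the assumption that the targets are also unit-length vectors; the helper names `l2_normalize`, `euclidean` and `cosine` are hypothetical and not taken from the project, whose own `euclidean_distance` is not shown.

```python
import math
from typing import List


def l2_normalize(v: List[float], eps: float = 1e-12) -> List[float]:
    """Scale v to unit L2 norm; eps guards against division by zero.
    (Hypothetical helper, similar in spirit to K.l2_normalize on one row.)"""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def euclidean(a: List[float], b: List[float]) -> float:
    """Plain Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# [3, 4, 0] has norm 5, so this normalizes to [0.6, 0.8, 0.0]
pred = l2_normalize([3.0, 4.0, 0.0])
target = l2_normalize([0.0, 1.0, 1.0])

d2 = euclidean(pred, target) ** 2
# For unit vectors both values agree up to floating-point rounding
print(d2, 2 - 2 * cosine(pred, target))
```

Since the final Dense layer in the snippet uses ReLU, each prediction is non-negative before normalization, so the normalized output is restricted to the non-negative part of the unit sphere; targets with negative components cannot be matched exactly.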