models.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:auto-triage 作者: zhijian-liu 项目源码 文件源码
def feature_extractor(FLAGS, suffix = ""):
  """Build a Keras ``applications`` backbone for use as a feature extractor.

  The architecture is chosen by ``FLAGS.model`` and instantiated from
  ``keras.applications``. The model is then passed to ``remove_last_layer``
  (defined elsewhere in this file; presumably strips the classification
  head -- confirm). It is renamed, and an optional regularizer is attached.

  Args:
    FLAGS: configuration object; three attributes are read:
      ``model`` -- one of "vgg16", "vgg19", "resnet50";
      ``weights`` -- forwarded as the Keras ``weights`` argument, except
        that the string "random" is mapped to ``None``;
      ``regularizer`` -- "l2" to call ``add_regularizer`` on the model,
        or "none" to skip it.
    suffix: string appended to ``FLAGS.model`` to form the model's name.

  Returns:
    The constructed Keras model.

  Raises:
    NotImplementedError: if ``FLAGS.model`` or ``FLAGS.regularizer`` is not
      a supported value. Both are checked before any network is built, so
      an invalid configuration fails without constructing a model or
      fetching its weights.
  """
  import importlib

  weights = FLAGS.weights if FLAGS.weights != "random" else None

  # model name -> (module path, class name). The import is deferred so that
  # only the requested architecture's module is loaded.
  architectures = {
    "vgg16": ("keras.applications.vgg16", "VGG16"),
    "vgg19": ("keras.applications.vgg19", "VGG19"),
    "resnet50": ("keras.applications.resnet50", "ResNet50"),
  }

  # Validate the whole configuration before building anything: constructing
  # a backbone is expensive, so a bad regularizer should not cost one.
  if FLAGS.model not in architectures:
    raise NotImplementedError(
      "unsupported model %r; expected one of %s"
      % (FLAGS.model, sorted(architectures)))
  if FLAGS.regularizer not in ("l2", "none"):
    raise NotImplementedError(
      "unsupported regularizer %r; expected 'l2' or 'none'"
      % (FLAGS.regularizer,))

  module_name, class_name = architectures[FLAGS.model]
  model_cls = getattr(importlib.import_module(module_name), class_name)
  model = model_cls(weights = weights)
  remove_last_layer(model)

  # NOTE(review): in some Keras releases Model.name is a read-only property;
  # this assignment assumes a Keras version where it is writable -- confirm.
  model.name = FLAGS.model + suffix

  if FLAGS.regularizer == "l2":
    add_regularizer(model)
  return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号