def load_vgg_weight(self, model):
    """Load pretrained VGG16 (no-top) weights into the model's VGG layers and freeze them.

    The HDF5 weight file is chosen by the backend's dimension ordering
    ("th" selects the Theano-ordered file, anything else the TensorFlow one)
    and fetched via ``get_file`` into the ``models`` cache subdirectory.
    Weights are assigned positionally: the i-th layer in ``self.vgg_layers``
    receives the weights of the i-th entry in the file's ``layer_names``
    attribute.

    Args:
        model: the model being populated. When ``self.vgg_layers`` is None,
            every layer of ``model`` whose name contains ``'vgg_'`` is
            collected (in ``model.layers`` order) and stored on
            ``self.vgg_layers``.

    Returns:
        The same ``model`` object, with its VGG layers loaded and marked
        ``trainable = False``.

    Raises:
        IndexError: if ``self.vgg_layers`` has more entries than the weight
            file lists in ``layer_names``.
    """
    # Pick the checkpoint matching the backend's dimension ordering.
    # NOTE(review): K.image_dim_ordering() is the Keras 1 API; Keras 2 replaced
    # it with K.image_data_format() -- confirm against the Keras version in use.
    if K.image_dim_ordering() == "th":
        weights_path = get_file('vgg16_weights_th_dim_ordering_th_kernels_notop.h5',
                                THEANO_WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models')
    else:
        weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                TF_WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models')

    # Open read-only (older h5py defaulted to a writable mode) and make sure the
    # handle is closed even if set_weights raises.
    with h5py.File(weights_path, 'r') as f:
        layer_names = list(f.attrs['layer_names'])
        if self.vgg_layers is None:
            self.vgg_layers = [layer for layer in model.layers
                               if 'vgg_' in layer.name]
        for i, layer in enumerate(self.vgg_layers):
            g = f[layer_names[i]]
            # `[()]` reads each dataset into memory so nothing handed to the
            # layer depends on the file staying open.
            layer_weights = [g[name][()] for name in g.attrs['weight_names']]
            layer.set_weights(layer_weights)

    # Freeze all VGG layers so training does not update the pretrained weights.
    for layer in self.vgg_layers:
        layer.trainable = False
    return model
models.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录