def copy_weights(src_model, dst_model, must_exist=True):
    """Copy weights from `src_model` to `dst_model`.

    Layers are matched by name: for every layer in `dst_model`, the layer
    with the same name in `src_model` provides the weights, which are read
    with `get_weights()` and written with `set_weights()`.

    Parameters
    ----------
    src_model
        Keras source model.
    dst_model
        Keras destination model.
    must_exist: bool
        If `True`, raises `ValueError` if a layer in `dst_model` does not
        exist in `src_model`. If `False`, such layers are skipped and keep
        their current weights.

    Returns
    -------
    list
        Names of layers that were copied, in `dst_model.layers` order.

    Raises
    ------
    ValueError
        If `must_exist` is `True` and a layer of `dst_model` has no
        same-named layer in `src_model`. Layers processed before the
        missing one have already been updated at that point.
    """
    # Index source layers by name once instead of rescanning them for every
    # destination layer. setdefault keeps the first layer seen for a name,
    # matching the first-match semantics of a linear search.
    src_layers = {}
    for layer in src_model.layers:
        src_layers.setdefault(layer.name, layer)

    copied = []
    for dst_layer in dst_model.layers:
        src_layer = src_layers.get(dst_layer.name)
        if src_layer is None:
            if must_exist:
                # Report the destination layer name: that is the one missing
                # from the source model.
                raise ValueError('Layer "%s" not found!' % dst_layer.name)
            continue
        dst_layer.set_weights(src_layer.get_weights())
        copied.append(dst_layer.name)
    return copied
评论列表
文章目录