def soft_copy_param(target_link, source_link, tau):
    """Soft-copy parameters of a link to another link.

    Every parameter of ``target_link`` is overwritten in place with
    ``(1 - tau) * target + tau * source``, where ``source`` is the
    same-named parameter of ``source_link``. The running statistics
    ``avg_mean`` and ``avg_var`` of each ``L.BatchNormalization`` sublink are
    blended the same way in a separate pass (Chainer registers them as
    persistent values, not parameters).

    Args:
        target_link: Link updated in place.
        source_link: Link blended into ``target_link``. Each of its
            parameter names and BatchNormalization sublink names must exist
            in ``target_link``.
        tau (float): Weight of ``source_link``. ``tau=1`` copies it exactly;
            ``tau=0`` leaves ``target_link`` unchanged.

    Raises:
        KeyError: If a parameter or BatchNormalization sublink name of
            ``source_link`` is missing from ``target_link``.
        TypeError: If a parameter of either link has no data yet (e.g. a
            lazily initialized parameter whose shape is still unknown).
            All parameters are validated before any is modified, so
            ``target_link`` is left untouched in that case.
    """
    keep = 1 - tau
    target_params = dict(target_link.namedparams())

    # Pair up and validate every parameter before touching any of them, so a
    # failure cannot leave the target half-updated.
    pairs = []
    for param_name, param in source_link.namedparams():
        target_param = target_params[param_name]
        if target_param.data is None or param.data is None:
            which = 'target' if target_param.data is None else 'source'
            raise TypeError(
                '{} parameter {} is not initialized (data is None). Forward '
                'a dummy input through the model beforehand to determine '
                'its parameter shapes.'.format(which, param_name))
        pairs.append((target_param, param))

    # Update through ``[:]`` so the existing target arrays are modified in
    # place instead of ``.data`` being rebound to new arrays.
    for target_param, param in pairs:
        target_param.data[:] *= keep
        target_param.data[:] += tau * param.data

    # Soft-copy Batch Normalization's statistics
    target_links = dict(target_link.namedlinks())
    for link_name, link in source_link.namedlinks():
        if isinstance(link, L.BatchNormalization):
            target_bn = target_links[link_name]
            target_bn.avg_mean[:] *= keep
            target_bn.avg_mean[:] += tau * link.avg_mean
            target_bn.avg_var[:] *= keep
            target_bn.avg_var[:] += tau * link.avg_var
评论列表
文章目录