def test_remap_var_list(self):
    """Exercise `dbinterface.remap_var_list` with a one-entry rename mapping.

    Builds a `{var.op.name: var}` dict from `tf.global_variables()`, installs
    a mapping from an old variable name to a new one as
    `dbinterface.load_param_dict`, and checks that the entry returned under
    the new name refers to the variable whose op name is the old name.
    """
    new_name = 'model_0/Filters'

    # Candidate variables keyed by their op name.
    graph_vars = {}
    for variable in tf.global_variables():
        graph_vars[variable.op.name] = variable

    # Old-name -> new-name mapping consumed by `remap_var_list`.
    rename_map = {'model_0/Weights': new_name}
    self.dbinterface.load_param_dict = rename_map

    remapped = self.dbinterface.remap_var_list(graph_vars)

    # NOTE(review): if no returned key equals `new_name`, no assertion runs
    # and the test passes without checking anything -- confirm intended.
    for key, variable in remapped.items():
        op_name = variable.op.name
        self.log.info('{} mapped to {}'.format(key, op_name))
        if key != new_name:
            continue
        # The variable stored under the new name must be the one whose
        # original op name maps to it.
        self.assertEqual(key, rename_map[op_name])
# (Web-page residue, not part of the test: "Comment list" / "Table of contents".)