def assign_from_pkl(self, pkl_path):
    """Assign pickled weights to the graph's global variables.

    The pickle at `pkl_path` is expected to hold an indexable sequence of
    array-like objects (each exposing `.shape`).  Global variables are
    matched to that sequence purely by position:

    * variables 0-41   read pickled entry ``i``
    * variables 42-77  read pickled entry ``i + 10``
    * variables 78-117 read pickled entry ``i + 20``
    * variables past 117 have no pickled counterpart and are initialized
      with their own initializers instead.

    Rank-1 arrays are assigned as-is; rank-3 and rank-4 arrays have their
    axes fully reversed before assignment (presumably because the weights
    were exported with the opposite axis order -- TODO confirm).

    Must be called with a default TensorFlow session active, because the
    assign/initializer ops are executed through ``.run()``.

    Args:
        pkl_path: Path of the pickle file containing the weights.

    Raises:
        ValueError: If a pickled array has a rank other than 1, 3 or 4.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file -- only use this with weight files from a trusted source.
    with open(pkl_path, 'rb') as f:
        load_variables = pickle.load(f)

    uninitialized_vars = []
    for i, variable in enumerate(tf.global_variables()):
        # Map the graph-variable position to its position in the pickled
        # list; the pickled list has 10 extra entries after index 41 and
        # another 10 after index 87 that have no graph counterpart.
        if i <= 41:
            idx = i
        elif i <= 77:
            idx = i + 10
        elif i <= 117:
            idx = i + 20
        else:
            # No pickled weights for this variable; initialize it later.
            uninitialized_vars.append(variable)
            continue

        source = load_variables[idx]
        rank = len(source.shape)
        if rank == 1:
            load_variable = source
        elif rank == 4:
            load_variable = np.transpose(source, [3, 2, 1, 0])
        elif rank == 3:
            load_variable = np.transpose(source, [2, 1, 0])
        else:
            # Explicit raise instead of `assert False`: asserts vanish under
            # `python -O`, which would silently assign a stale array.
            raise ValueError(
                'unsupported rank %d (shape %s) at pickled index %d for '
                'variable %s' % (rank, source.shape, idx, variable.name))

        # Single-argument print: same output under Python 2 and Python 3.
        print('%s %s %s' % (variable.name, variable.get_shape(),
                            load_variable.shape))
        variable.assign(load_variable).op.run()

    # NOTE: tf.initialize_variables is the deprecated spelling of
    # tf.variables_initializer; kept so the required TF version is unchanged.
    tf.initialize_variables(uninitialized_vars).op.run()
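# Comment list  (navigation residue from the source web page; not code)
# Table of contents  (navigation residue from the source web page; not code)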