def testInitFromPartitionVar(self):
    """Check that init_from_checkpoint restores partitioned variables.

    A checkpoint is written once and then loaded into two fresh graphs:

    * The first graph maps the checkpoint entry "scope/var1" to the variable
      named "some_scope/my1", and the checkpoint scope "scope/" to the scope
      "some_other_scope/".
    * The second graph maps "scope/var1" directly to the list of partition
      variables.

    In each case the values read back from the partition variables must
    equal the value returned by `_create_partition_checkpoints`.
    """

    def _partitioned_var(name):
        # Create a [100, 100] variable in the current variable scope,
        # partitioned along axis 0 into at most 5 slices.
        return variable_scope.get_variable(
            name=name,
            shape=[100, 100],
            initializer=init_ops.truncated_normal_initializer(0.5),
            partitioner=partitioned_variables.min_max_variable_partitioner(
                max_partitions=5, axis=0, min_slice_size=8 << 10))

    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        # NOTE(review): helper is defined elsewhere in this file; presumably
        # it saves a partitioned "scope/var1" and returns its per-partition
        # values -- confirm against its definition.
        v1 = _create_partition_checkpoints(session, checkpoint_dir)

    # New graph and session: assignment map given as names / scope prefixes.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope("some_scope"):
                my1 = _partitioned_var("my1")
                my1_var_list = my1._get_variable_list()
            with variable_scope.variable_scope("some_other_scope"):
                # Named "var1" so that the "scope/" -> "some_other_scope/"
                # mapping pairs it with the checkpoint's "scope/var1".
                my2 = _partitioned_var("var1")
                my2_var_list = my2._get_variable_list()

            checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
                "scope/var1": "some_scope/my1",
                "scope/": "some_other_scope/",
            })

            # Values are only inspected after the initializers have run.
            session.run(variables.global_variables_initializer())
            my1_values = session.run(my1_var_list)
            self.assertAllEqual(my1_values, v1)
            my2_values = session.run(my2_var_list)
            self.assertAllEqual(my2_values, v1)

    # New graph and session: assignment map given as a list of partitions.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as session:
            with variable_scope.variable_scope("some_scope"):
                my1 = _partitioned_var("my1")
                my1_var_list = my1._get_variable_list()

            checkpoint_utils.init_from_checkpoint(
                checkpoint_dir, {"scope/var1": my1_var_list})

            session.run(variables.global_variables_initializer())
            my1_values = session.run(my1_var_list)
            self.assertAllEqual(my1_values, v1)
checkpoint_utils_test.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录