def __call__(self, shape, dtype=None, partition_info=None):
  """Build a graph op that restores this variable's value from a checkpoint.

  The checkpoint key is ``self._var_name``, optionally prefixed by a scope
  name and a ``/`` separator. The value stored under that key in
  ``self._filename`` is read back with ``io_ops.restore_v2``.

  Args:
    shape: Shape of the variable being initialized. It is forwarded to
      ``self._partition_spec`` and applied to the result via ``set_shape``.
    dtype: Element type to restore, passed as the only entry of the
      ``restore_v2`` dtypes list. NOTE(review): defaults to None and is
      forwarded unchanged; presumably callers always supply it — confirm.
    partition_info: Partitioning information, forwarded to
      ``self._partition_spec``.

  Returns:
    The first (and only requested) tensor produced by ``restore_v2``, with
    its static shape set to ``shape``.
  """
  # A separate RestoreV2 op is built per variable even though one op could
  # emit several tensors. tf.Saver.restore_op (via tf.BaseSaverBuilder)
  # takes the same approach, so the apparent inefficiency is intentional.
  scope = self._scope
  if scope is not None and not callable(scope):
    # A plain (non-callable) scope is used verbatim as the key prefix.
    prefix = scope
  else:
    # Otherwise derive the prefix from the variable scope active at call
    # time. A callable scope acts as a transform of that name.
    active_scope_name = tf.get_variable_scope().name
    if scope is None:
      prefix = active_scope_name
    else:
      prefix = scope(active_scope_name)

  # An empty/falsy prefix means the bare variable name is the checkpoint key.
  if prefix:
    checkpoint_key = '{}/{}'.format(prefix, self._var_name)
  else:
    checkpoint_key = self._var_name

  slice_spec = self._partition_spec(shape, partition_info)
  restored_tensors = io_ops.restore_v2(
      self._filename, [checkpoint_key], [slice_spec], [dtype])
  result = restored_tensors[0]
  result.set_shape(shape)
  return result
# pylint: disable=invalid-name
评论列表
文章目录