def test_pickle_bug(self):
    # Regression test for bug fixed in 24d4fd291054.
    o = Prod()
    s = pickle.dumps(o, protocol=-1)
    o = pickle.loads(s)
    pickle.dumps(o)
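A minimal, self-contained sketch of the same round trip (DummyProd is a hypothetical stand-in for the Prod object used above): protocol=-1 simply selects pickle.HIGHEST_PROTOCOL, and re-pickling the reloaded object is what exercised the original bug.

import pickle

class DummyProd(object):
    def __init__(self):
        self.items = [1, 2, 3]

o = DummyProd()
s = pickle.dumps(o, protocol=-1)   # -1 selects pickle.HIGHEST_PROTOCOL
o2 = pickle.loads(s)
pickle.dumps(o2)                   # dumping the reloaded object should also succeed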
def test_pickle(self):
    a = T.scalar()  # the a is for 'anonymous' (un-named).
    x, s = T.scalars('xs')

    f = function([x, In(a, value=1.0, name='a'),
                  In(s, value=0.0, update=s + a * x, mutable=True)],
                 s + a * x)

    try:
        # Note that here we also test protocol 0 on purpose, since it
        # should work (even though one should not use it).
        g = pickle.loads(pickle.dumps(f, protocol=0))
        g = pickle.loads(pickle.dumps(f, protocol=-1))
    except NotImplementedError as e:
        if e.args[0].startswith('DebugMode is not picklable'):
            return
        else:
            raise

    # if they both return, assume that they return equivalent things.
    # print [(k,id(k)) for k in f.finder.keys()]
    # print [(k,id(k)) for k in g.finder.keys()]

    self.assertFalse(g.container[0].storage is f.container[0].storage)
    self.assertFalse(g.container[1].storage is f.container[1].storage)
    self.assertFalse(g.container[2].storage is f.container[2].storage)
    self.assertFalse(x in g.container)
    self.assertFalse(x in g.value)

    self.assertFalse(g.value[1] is f.value[1])  # should not have been copied
    self.assertFalse(g.value[2] is f.value[2])  # should have been copied because it is mutable.
    self.assertFalse((g.value[2] != f.value[2]).any())  # its contents should be identical

    self.assertTrue(f(2, 1) == g(2))  # they should be in sync, default value should be copied.
    self.assertTrue(f(2, 1) == g(2))  # they should be in sync, default value should be copied.
    f(1, 2)  # put them out of sync
    self.assertFalse(f(1, 2) == g(1, 2))  # they should not be equal anymore.
def startup(self, recording_requester):
    """
    Prepare for a new run and create/update the abs2prom and prom2abs variables.

    Parameters
    ----------
    recording_requester :
        Object to which this recorder is attached.
    """
    super(SqliteRecorder, self).startup(recording_requester)

    # grab the system
    if isinstance(recording_requester, Driver):
        system = recording_requester._problem.model
    elif isinstance(recording_requester, System):
        system = recording_requester
    else:
        system = recording_requester._system

    # merge current abs2prom and prom2abs with this system's version
    for io in ['input', 'output']:
        for v in system._var_abs2prom[io]:
            self._abs2prom[io][v] = system._var_abs2prom[io][v]
        for v in system._var_allprocs_prom2abs_list[io]:
            if v not in self._prom2abs[io]:
                self._prom2abs[io][v] = system._var_allprocs_prom2abs_list[io][v]
            else:
                self._prom2abs[io][v] = list(set(self._prom2abs[io][v]) |
                                             set(system._var_allprocs_prom2abs_list[io][v]))

    # store the updated abs2prom and prom2abs
    abs2prom = pickle.dumps(self._abs2prom)
    prom2abs = pickle.dumps(self._prom2abs)

    if self._open_close_sqlite:
        with self.con:
            self.con.execute("UPDATE metadata SET abs2prom=?, prom2abs=?",
                             (abs2prom, prom2abs))
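The two pickled dictionaries stored above can be recovered with pickle.loads. A minimal sketch of the read side, assuming only the metadata table layout implied by the UPDATE statement (the file name and connection handling here are hypothetical, not part of the original code):

import pickle
import sqlite3

con = sqlite3.connect("cases.sql")   # hypothetical recorder output file
row = con.execute("SELECT abs2prom, prom2abs FROM metadata").fetchone()
abs2prom = pickle.loads(row[0])
prom2abs = pickle.loads(row[1])
con.close()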
def record_metadata_driver(self, recording_requester):
    """
    Record driver metadata.

    Parameters
    ----------
    recording_requester : <Driver>
        The Driver that would like to record its metadata.
    """
    driver_class = type(recording_requester).__name__
    model_viewer_data = pickle.dumps(recording_requester._model_viewer_data,
                                     pickle.HIGHEST_PROTOCOL)
    with self.con:
        self.con.execute("INSERT INTO driver_metadata(id, model_viewer_data) VALUES(?,?)",
                         (driver_class, sqlite3.Binary(model_viewer_data)))
def record_metadata_system(self, recording_requester):
    """
    Record system metadata.

    Parameters
    ----------
    recording_requester : <System>
        The System that would like to record its metadata.
    """
    # Cannot handle PETScVector yet
    from openmdao.api import PETScVector
    if PETScVector and isinstance(recording_requester._outputs, PETScVector):
        return  # Cannot handle PETScVector yet

    # collect scaling arrays
    scaling_vecs = {}
    for kind, odict in iteritems(recording_requester._vectors):
        scaling_vecs[kind] = scaling = {}
        for vecname, vec in iteritems(odict):
            scaling[vecname] = vec._scaling
    scaling_factors = pickle.dumps(scaling_vecs,
                                   pickle.HIGHEST_PROTOCOL)

    path = recording_requester.pathname
    if not path:
        path = 'root'
    with self.con:
        self.con.execute("INSERT INTO system_metadata(id, scaling_factors) VALUES(?,?)",
                         (path, sqlite3.Binary(scaling_factors)))
def record_metadata_driver(self, recording_requester):
    """
    Record driver metadata.

    Parameters
    ----------
    recording_requester : <Driver>
        The Driver that would like to record its metadata.
    """
    driver_class = type(recording_requester).__name__
    model_viewer_data = json.dumps(recording_requester._model_viewer_data)
    self._record_driver_metadata(driver_class, model_viewer_data)
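Unlike the SqliteRecorder variant above, this recorder serializes the model viewer data with json.dumps, since the payload is posted to an HTTP endpoint rather than stored as a binary blob. A small illustration (not part of the original code) of the trade-off: json.dumps only accepts JSON-compatible types, while pickle.dumps will serialize arbitrary picklable Python objects.

import json
import pickle

data = {"name": "root", "children": []}
json.dumps(data)        # fine: dicts, lists, strings and numbers map directly to JSON
pickle.dumps(data)      # also fine

class Node(object):     # an arbitrary object, used only for illustration
    pass

pickle.dumps(Node())    # works: pickle handles any picklable Python object
try:
    json.dumps(Node())  # raises TypeError: Node is not JSON serializable
except TypeError:
    pass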
def test_pickle(self):
    self.test_file_name_property()
    name = "file"
    file1 = os.path.join(self.tmp_dir, name)
    wrap = FileWrapper(file1)
    pickled_data = pickle.dumps(wrap)
    wrap2 = pickle.loads(pickled_data)
    print(wrap2.file_path)
def test_pickle(self):
    rpm_version = [int(v) for v in getattr(rpm, '__version__', '0.0').split('.')]
    if rpm_version[0:2] < [4, 10]:
        warnings.warn('RPM header pickling unsupported in rpm %s' % rpm_version)
        return
    wrap = RpmWrapper(self.file_path)
    pickled_data = pickle.dumps(wrap)
    wrap2 = pickle.loads(pickled_data)
    self.assertEqual(wrap.name, wrap2.name)
def test_pickle(self):
    wrap = SimpleRpmWrapper(self.file_path)
    pickled_data = pickle.dumps(wrap)
    wrap2 = pickle.loads(pickled_data)
    self.assertEqual(wrap.name, wrap2.name)
def __call__(self, *args, **kwargs):
    # If the function args cannot be used as a cache hash key, fail fast
    key = pickle.dumps((args, kwargs))
    try:
        return self.cache[key]
    except KeyError:
        value = self.func(*args, **kwargs)
        self.cache[key] = value
        return value
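This __call__ is the core of a pickle-keyed memoization decorator; the enclosing class (not shown in the snippet) presumably stores func and cache. A minimal, self-contained version written here as an illustration, not the original implementation:

import pickle

class memoized(object):
    """Cache results keyed on the pickled (args, kwargs) of each call."""
    def __init__(self, func):
        self.func = func
        self.cache = {}

    def __call__(self, *args, **kwargs):
        key = pickle.dumps((args, kwargs))
        try:
            return self.cache[key]
        except KeyError:
            value = self.func(*args, **kwargs)
            self.cache[key] = value
            return value

@memoized
def slow_add(a, b):
    return a + b

slow_add(1, 2)   # computed and cached
slow_add(1, 2)   # served from the cache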
def _serialize_key(self, key):
    return cPickle.dumps(key)
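_serialize_key relies on cPickle, which only exists under that name in Python 2; a common compatibility import (an assumption about how these modules obtain cPickle, since the import is not shown in the snippet) is:

try:
    import cPickle            # Python 2: C implementation of pickle
except ImportError:
    import pickle as cPickle  # Python 3: the C accelerator is used automatically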
def __init__(self, data_desc, dtype=None,
             batch_filter=None, batch_mode='batch',
             ncpu=1, buffer_size=8, hwm=86,
             mpi_backend='python'):
    super(Feeder, self).__init__(data=as_tuple(data_desc, t=DataDescriptor),
                                 read_only=True)
    # find intersection of all indices in DataDescriptor
    self._indices_keys = async(
        lambda: np.array(
            list(set.intersection(*[set(dat.indices.keys())
                                    for dat in self._data])),
            dtype=str)
    )()
    # ====== desired dtype ====== #
    nb_data = sum(len(dat._data) for dat in self._data)
    self._output_dtype = as_tuple(dtype, N=nb_data)
    # ====== set default recipes ====== #
    self._recipes = RecipeList()
    self._recipes.set_feeder_info(nb_desc=len(self._data))
    self.set_multiprocessing(ncpu, buffer_size, hwm, mpi_backend)
    # ====== cache shape information ====== #
    # store first dimension
    self._cache_shape = None
    # if the recipes changed, the shape needs to be recalculated
    self._recipes_changed = False
    # ====== iteration information ====== #
    self._running_iter = []
    # ====== batch mode ====== #
    if batch_filter is None:
        batch_filter = _dummy_batch_filter
    elif not hasattr(batch_filter, '__call__'):
        raise ValueError('batch_filter must be a function with 1 or 2 '
                         'parameters: (X) or (X, y).')
    # check if batch_filter is picklable
    try:
        cPickle.dumps(batch_filter, protocol=2)
    except Exception:
        raise ValueError("`batch_filter` must be pickle-able, i.e. it must be "
                         "a top-level function.")
    self._batch_filter = batch_filter
    # check batch_mode
    batch_mode = str(batch_mode).lower()
    if batch_mode not in ('batch', 'file'):
        raise ValueError("Only `batch_mode` 'batch' or 'file' is supported, but "
                         "the given value is: '%s'" % batch_mode)
    self._batch_mode = batch_mode
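The picklability check above is why batch_filter must be a top-level function: pickle serializes functions by reference to their importable, module-level name, so lambdas and nested closures fail. A small illustration (not from the original code) of the behaviour the check guards against:

import pickle

def top_level_filter(X):    # importable by name, so picklable
    return X

pickle.dumps(top_level_filter, protocol=2)   # works

try:
    pickle.dumps(lambda X: X, protocol=2)    # fails: a lambda has no importable name
except (pickle.PicklingError, AttributeError, TypeError):
    print("lambda is not picklable")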
# ==================== pickling ==================== #
def _flush(self, save_all=False):
    """
    Parameters
    ----------
    save_all: bool
        force the indices dictionary to be saved, even though its
        size increase hasn't reached the maximum yet.
    """
    # check if closed or in read-only mode
    if self.is_closed or self.read_only:
        return
    # ====== write new data ====== #
    # get old position
    file = self._file
    # start from header (i.e. "mmapdict")
    file.seek(len(MmapDict.HEADER))
    max_position = int(file.read(MmapDict.SIZE_BYTES))
    # ====== serialize the data ====== #
    # start from old max_position, append new values
    file.seek(max_position)
    for key, value in self._cache_dict.items():
        try:
            value = marshal.dumps(value)
        except ValueError:
            raise RuntimeError("Cannot marshal.dumps %s" % str(value))
        self.indices[key] = (max_position, len(value))
        max_position += len(value)
        file.write(value)
        # increase indices size (in megabytes)
        self._increased_indices_size += (8 + 8 + len(key)) / 1024. / 1024.
    # ====== write the dumped indices ====== #
    indices_length = 0
    if save_all or \
            self._increased_indices_size > MmapDict.MAX_INDICES_SIZE:
        indices_dump = cPickle.dumps(self.indices,
                                     protocol=cPickle.HIGHEST_PROTOCOL)
        indices_length = len(indices_dump)
        file.write(indices_dump)
        self._increased_indices_size = 0.
    # ====== update the position ====== #
    # write the new max size
    file.seek(len(MmapDict.HEADER))
    max_position = ('%' + str(MmapDict.SIZE_BYTES) + 'd') % max_position
    file.write(max_position.encode())
    # update the length of the pickled indices dictionary
    if indices_length > 0:
        indices_length = ('%' + str(MmapDict.SIZE_BYTES) + 'd') % indices_length
        file.write(indices_length.encode())
    # flush everything
    file.flush()
    # update the mmap
    self._mmap.close()
    del self._mmap
    self._mmap = mmap.mmap(file.fileno(), length=0, offset=0,
                           flags=mmap.MAP_SHARED)
    # reset some values
    del self._cache_dict
    self._cache_dict = {}
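_flush serializes values with marshal, which handles only core built-in types (numbers, strings, bytes and containers of these) and raises ValueError for anything else; that is why the code wraps the call and re-raises a RuntimeError. A quick illustration (not from the original code):

import marshal

marshal.dumps({"a": 1, "b": [2, 3]})    # core built-in types are fine

class Custom(object):
    pass

try:
    marshal.dumps(Custom())             # unsupported type
except ValueError:
    print("marshal only supports core built-in types")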
# ==================== I/O methods ==================== #
def test_optimizations_preserved(self):
    a = T.dvector()  # the a is for 'anonymous' (un-named).
    x = T.dvector('x')
    s = T.dvector('s')
    xm = T.dmatrix('x')
    sm = T.dmatrix('s')

    f = function([a, x, s, xm, sm],
                 ((a.T.T) * (tensor.dot(xm, (sm.T.T.T)) + x).T * (x / x) + s))
    old_default_mode = config.mode
    old_default_opt = config.optimizer
    old_default_link = config.linker
    try:
        try:
            str_f = pickle.dumps(f, protocol=-1)
            config.mode = 'Mode'
            config.linker = 'py'
            config.optimizer = 'None'
            g = pickle.loads(str_f)
            # print g.maker.mode
            # print compile.mode.default_mode
        except NotImplementedError as e:
            if e.args[0].startswith('DebugMode is not pickl'):
                g = 'ok'
    finally:
        config.mode = old_default_mode
        config.optimizer = old_default_opt
        config.linker = old_default_link

    if g == 'ok':
        return

    assert f.maker is not g.maker
    assert f.maker.fgraph is not g.maker.fgraph

    tf = f.maker.fgraph.toposort()
    tg = g.maker.fgraph.toposort()  # compare the unpickled copy against the original
    assert len(tf) == len(tg)
    for nf, ng in zip(tf, tg):
        assert nf.op == ng.op
        assert len(nf.inputs) == len(ng.inputs)
        assert len(nf.outputs) == len(ng.outputs)
        assert [i.type for i in nf.inputs] == [i.type for i in ng.inputs]
        assert [i.type for i in nf.outputs] == [i.type for i in ng.outputs]
def __init__(self, token, case_name='Case Recording',
             endpoint='http://www.openmdao.org/visualization', port='', case_id=None,
             suppress_output=False):
    """
    Initialize the OpenMDAOServerRecorder.

    Parameters
    ----------
    token : <string>
        The token to be passed as a user's unique identifier. Register to get a token
        at the given endpoint.
    case_name : <string>
        The name this case should be stored under. Default: 'Case Recording'.
    endpoint : <string>
        The URL (minus port, if not 80) where the server is hosted.
    port : <string>
        The port which the server is listening on. Defaults to an empty string (port 80).
    case_id : <string>
        If given, the ID of an existing case to record to; when omitted, a new case
        is created on the server.
    suppress_output : <bool>
        Indicates if all printing should be suppressed in this recorder.
    """
    super(WebRecorder, self).__init__()

    self.model_viewer_data = None
    self._headers = {'token': token, 'update': "False"}
    if port != '':
        self._endpoint = endpoint + ':' + port + '/case'
    else:
        self._endpoint = endpoint + '/case'

    self._abs2prom = {'input': {}, 'output': {}}
    self._prom2abs = {'input': {}, 'output': {}}

    if case_id is None:
        case_data_dict = {
            'case_name': case_name,
            'owner': 'temp_owner'
        }
        case_data = json.dumps(case_data_dict)

        # post the case and get its Case ID
        case_request = requests.post(self._endpoint, data=case_data, headers=self._headers)
        response = case_request.json()
        if response['status'] != 'Failed':
            self._case_id = str(response['case_id'])
        else:
            self._case_id = '-1'
            if not suppress_output:
                print("Failed to initialize case on server. No messages will be accepted "
                      "from the server for this case. Make sure you registered for a token "
                      "at the given endpoint.")

            if 'reasoning' in response:
                if not suppress_output:
                    print("Failure reasoning: " + response['reasoning'])
    else:
        self._case_id = str(case_id)
        self._headers['update'] = "True"
def _record_driver_iteration(self, counter, iteration_coordinate, success, msg,
                             desvars, responses, objectives, constraints, sysincludes):
    """
    Record a driver iteration.

    Parameters
    ----------
    counter : int
        The global counter associated with this iteration.
    iteration_coordinate : str
        The iteration coordinate to identify this iteration.
    success : int
        Integer to indicate success.
    msg : str
        The metadata message.
    desvars : [JSON]
        The array of JSON objects representing the design variables.
    responses : [JSON]
        The array of JSON objects representing the responses.
    objectives : [JSON]
        The array of JSON objects representing the objectives.
    constraints : [JSON]
        The array of JSON objects representing the constraints.
    sysincludes : [JSON]
        The array of JSON objects representing the system variables explicitly included
        in the options.
    """
    driver_iteration_dict = {
        "counter": counter,
        "iteration_coordinate": iteration_coordinate,
        "success": success,
        "msg": msg,
        "desvars": [] if desvars is None else desvars,
        "responses": [] if responses is None else responses,
        "objectives": [] if objectives is None else objectives,
        "constraints": [] if constraints is None else constraints,
        "sysincludes": [] if sysincludes is None else sysincludes
    }

    global_iteration_dict = {
        'record_type': 'driver',
        'counter': counter
    }

    driver_iteration = json.dumps(driver_iteration_dict)
    global_iteration = json.dumps(global_iteration_dict)

    requests.post(self._endpoint + '/' + self._case_id + '/driver_iterations',
                  data=driver_iteration, headers=self._headers)
    requests.post(self._endpoint + '/' + self._case_id + '/global_iterations',
                  data=global_iteration, headers=self._headers)
def _record_system_iteration(self, counter, iteration_coordinate, success, msg,
                             inputs, outputs, residuals):
    """
    Record a system iteration.

    Parameters
    ----------
    counter : int
        The global counter associated with this iteration.
    iteration_coordinate : str
        The iteration coordinate to identify this iteration.
    success : int
        Integer to indicate success.
    msg : str
        The metadata message.
    inputs : [JSON]
        The array of JSON objects representing the inputs.
    outputs : [JSON]
        The array of JSON objects representing the outputs.
    residuals : [JSON]
        The array of JSON objects representing the residuals.
    """
    system_iteration_dict = {
        'counter': counter,
        'iteration_coordinate': iteration_coordinate,
        'success': success,
        'msg': msg,
        'inputs': [] if inputs is None else inputs,
        'outputs': [] if outputs is None else outputs,
        'residuals': [] if residuals is None else residuals
    }

    global_iteration_dict = {
        'record_type': 'system',
        'counter': counter
    }

    system_iteration = json.dumps(system_iteration_dict)
    global_iteration = json.dumps(global_iteration_dict)

    requests.post(self._endpoint + '/' + self._case_id + '/system_iterations',
                  data=system_iteration, headers=self._headers)
    requests.post(self._endpoint + '/' + self._case_id + '/global_iterations',
                  data=global_iteration, headers=self._headers)
def _record_solver_iteration(self, counter, iteration_coordinate, success, msg,
                             abs_error, rel_error, outputs, residuals):
    """
    Record a solver iteration.

    Parameters
    ----------
    counter : int
        The global counter associated with this iteration.
    iteration_coordinate : str
        The iteration coordinate to identify this iteration.
    success : int
        Integer to indicate success.
    msg : str
        The metadata message.
    abs_error : float
        The absolute error.
    rel_error : float
        The relative error.
    outputs : [JSON]
        The array of JSON objects representing the outputs.
    residuals : [JSON]
        The array of JSON objects representing the residuals.
    """
    solver_iteration_dict = {
        'counter': counter,
        'iteration_coordinate': iteration_coordinate,
        'success': success,
        'msg': msg,
        'abs_err': abs_error,
        'rel_err': rel_error,
        'solver_output': [] if outputs is None else outputs,
        'solver_residuals': [] if residuals is None else residuals
    }

    global_iteration_dict = {
        'record_type': 'solver',
        'counter': counter
    }

    solver_iteration = json.dumps(solver_iteration_dict)
    global_iteration = json.dumps(global_iteration_dict)

    requests.post(self._endpoint + '/' + self._case_id + '/solver_iterations',
                  data=solver_iteration, headers=self._headers)
    requests.post(self._endpoint + '/' + self._case_id + '/global_iterations',
                  data=global_iteration, headers=self._headers)