def load(self, filename):
    """
    Load the parameters for this network from disk.

    The pickle is expected to be a dict holding a 'network' entry (the string
    form of the network that was saved) and one '<layerNum>-values' entry per
    layer with the saved parameter values.

    If a layer has a different number of parameters than were saved, each
    parameter is matched to the first not-yet-used saved value with the same
    shape. Every saved value is used at most once. Otherwise, parameters are
    assigned positionally, and any with mismatched shapes are skipped with a
    warning.

    NOTE(review): unpickling can execute arbitrary code -- only load files
    from trusted sources.

    :param filename: Load the parameters of this network from a pickle file at the named path. If this name ends in
        ".gz" then the input will automatically be gunzipped; otherwise the input will be treated as a "raw" pickle.
    :return: None
    :raises ImportError: if shape-based fitting cannot find a distinct saved value for every parameter of a layer.
        In that case no parameter of that layer is modified.
    """
    opener = gzip.open if filename.lower().endswith('.gz') else open
    # 'with' guarantees the handle is closed even if unpickling fails.
    with opener(filename, 'rb') as handle:
        saved = cPickle.load(handle)

    current = self.__str__()
    if saved['network'] != current:
        print("Possibly not matching network configuration!")
        differences = difflib.Differ().compare(saved['network'].splitlines(), current.splitlines())
        print("Differences are:")
        print("\n".join(differences))

    for layer in self.layers:
        values = saved['{}-values'.format(layer.layerNum)]
        if len(layer.params) != len(values):
            print("Warning: Layer parameters for layer {} do not match. Trying to fit on shape!".format(layer.layerNum))
            # Pair each parameter with the first unused saved value of equal shape.
            # Removing used values stops one array from being assigned to several params.
            unused = list(values)
            matches = []
            for p in layer.params:
                shape = p.get_value().shape
                for i, v in enumerate(unused):
                    if v.shape == shape:
                        matches.append((p, v))
                        del unused[i]  # safe: we break out of the loop right away
                        break
            if len(matches) != len(layer.params):
                # Raise before touching any parameter so the layer is not left half-loaded.
                raise ImportError("Could not load all necessary variables!")
            for p, v in matches:
                p.set_value(v)
            print("Found fitting parameters!")
        else:
            for p, v in zip(layer.params, values):
                if p.get_value().shape == v.shape:
                    p.set_value(v)
                else:
                    print("WARNING: Skipping parameter for {}! Shape {} does not fit {}.".format(p.name, p.get_value().shape, v.shape))
    print('Loaded model parameters from {}'.format(filename))
评论列表
文章目录