netbase.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:deep-prior 作者: moberweger 项目源码 文件源码
def load(self, filename):
        """
        Load the parameters for this network from disk.

        The pickle is expected to hold a dict with a 'network' entry (a textual description of the saved network,
        compared against ``self.__str__()``) and one '<layerNum>-values' entry per layer holding that layer's
        saved parameter values. A differing description only produces a printed diff; loading continues.

        For each layer, if the number of saved values equals the number of parameters, they are assigned
        pairwise and pairs with differing shapes are skipped with a warning. Otherwise every parameter is given
        the first not-yet-used saved value of the same shape.

        :param filename: Load the parameters of this network from a pickle file at the named path. If this name ends in
               ".gz" then the input will automatically be gunzipped; otherwise the input will be treated as a "raw" pickle.
        :return: None
        :raises ImportError: if the parameter count of a layer differs from the saved one and not every parameter
                 of that layer could be matched to a saved value by shape
        """

        opener = gzip.open if filename.lower().endswith('.gz') else open
        handle = opener(filename, 'rb')
        try:
            # NOTE(security): unpickling can execute arbitrary code; only load parameter files from trusted sources.
            saved = cPickle.load(handle)
        finally:
            # Release the file even when unpickling fails.
            handle.close()

        # print(...) with a single argument behaves identically as a Python 2 statement and a Python 3 call.
        current = self.__str__()
        if saved['network'] != current:
            print("Possibly not matching network configuration!")
            differences = list(difflib.Differ().compare(saved['network'].splitlines(), current.splitlines()))
            print("Differences are:")
            print("\n".join(differences))

        for layer in self.layers:
            values = saved['{}-values'.format(layer.layerNum)]
            if len(layer.params) != len(values):
                print("Warning: Layer parameters for layer {} do not match. Trying to fit on shape!".format(layer.layerNum))
                # Each saved value may be consumed by one parameter only, and each parameter is counted at most
                # once; otherwise duplicate shapes would inflate the count and hide parameters left unassigned.
                unused = list(values)
                n_assigned = 0
                for p in layer.params:
                    for idx, v in enumerate(unused):
                        if p.get_value().shape == v.shape:
                            p.set_value(v)
                            del unused[idx]
                            n_assigned += 1
                            break

                if n_assigned != len(layer.params):
                    raise ImportError("Could not load all necessary variables!")
                print("Found fitting parameters!")
            else:
                for p, v in zip(layer.params, values):
                    if p.get_value().shape == v.shape:
                        p.set_value(v)
                    else:
                        print("WARNING: Skipping parameter for {}! Shape {} does not fit {}.".format(p.name, p.get_value().shape, v.shape))
        print('Loaded model parameters from {}'.format(filename))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号