Python numpy.savez() 函数的实例源码

tagger.py 文件源码 项目:deep_srl 作者: luheng 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def save(self, filepath):
    """Write every model parameter into a single ``.npz`` archive.

    Each entry of ``self.params`` contributes one array, keyed by the
    parameter's ``name`` and holding whatever ``param.get_value()`` returns.
    If two parameters share a name, the later value wins.

    Args:
        filepath: Destination handed to ``numpy.savez``.
    """
    named_values = OrderedDict()
    for param in self.params:
        named_values[param.name] = param.get_value()
    numpy.savez(filepath, **named_values)
    print('Saved model to: {}'.format(filepath))
n02_convert.py 文件源码 项目:kaggle_yt8m 作者: N01Z3 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def tf2npz(tf_path, export_folder=FAST):
    """Convert one TFRecord file of video-level examples into an ``.npz`` file.

    Every ``tf.train.Example`` in *tf_path* contributes its ``video_id``,
    its ``mean_rgb`` and ``mean_audio`` vectors (as float32) and, for
    non-test files, its ``labels`` list.  The rgb and audio matrices are
    standardized with a ``StandardScaler`` fitted on this file alone, and
    everything is written to ``<export_folder>/<basename>.npz``.

    Args:
        tf_path: Path to a ``.tfrecord`` file.  A path containing ``'/test'``
            is treated as unlabeled and gets an empty ``labels`` entry.
        export_folder: Directory the ``.npz`` file is written into.
    """
    vid_ids = []
    labels = []
    mean_rgb = []
    mean_audio = []
    tf_basename = os.path.basename(tf_path)
    # Assumes the file name ends in '.tfrecord'.
    npz_basename = tf_basename[:-len('.tfrecord')] + '.npz'
    isTrain = '/test' not in tf_path

    for example in tf.python_io.tf_record_iterator(tf_path):
        tf_example = tf.train.Example.FromString(example).features
        vid_ids.append(tf_example.feature['video_id'].bytes_list.value[0].decode(encoding='UTF-8'))
        if isTrain:
            labels.append(np.array(tf_example.feature['labels'].int64_list.value))
        mean_rgb.append(np.array(tf_example.feature['mean_rgb'].float_list.value).astype(np.float32))
        mean_audio.append(np.array(tf_example.feature['mean_audio'].float_list.value).astype(np.float32))

    # Videos carry a variable number of labels.  Older NumPy silently turned
    # such a ragged list into a 1-D object array; NumPy >= 1.24 raises
    # ValueError instead, so build that object array explicitly when needed.
    # Reading it back requires ``np.load(..., allow_pickle=True)``.
    try:
        labels_array = np.array(labels)
    except ValueError:
        labels_array = np.empty(len(labels), dtype=object)
        for i, video_labels in enumerate(labels):
            labels_array[i] = video_labels

    save_path = os.path.join(export_folder, npz_basename)
    np.savez(save_path,
             rgb=StandardScaler().fit_transform(np.array(mean_rgb)),
             audio=StandardScaler().fit_transform(np.array(mean_audio)),
             ids=np.array(vid_ids),
             labels=labels_array,
             )
featurize_instances.py 文件源码 项目:quoll 作者: LanguageMachines 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def run(self):
    """Featurize one tokenized file and write the sparse instances.

    Each document yielded by ``format_tokdoc`` is wrapped in its own
    single-element list before featurization.  The resulting sparse matrix
    is stored as its ``data``/``indices``/``indptr``/``shape`` parts in
    ``self.out_features().path`` via ``numpy.savez``.  The vocabulary is
    written newline-separated (UTF-8) to ``self.out_vocabulary().path``.
    """
    token_settings = {
        'n_list': self.ngrams.split(),
        'blackfeats': self.blackfeats.split(),
        'mt': self.minimum_token_frequency,
    }
    features = {'tokens': token_settings}

    tokenized_path = self.in_tokenized().path
    documents = [[doc] for doc in format_tokdoc(tokenized_path, self.lowercase)]

    ft = featurizer.Featurizer(documents, features)
    ft.fit_transform()
    instances, vocabulary = ft.return_instances(['tokens'])

    features_path = self.out_features().path
    sparse_parts = dict(data=instances.data, indices=instances.indices,
                        indptr=instances.indptr, shape=instances.shape)
    numpy.savez(features_path, **sparse_parts)

    vocab_lines = list(vocabulary)
    with open(self.out_vocabulary().path, 'w', encoding='utf-8') as vocab_out:
        vocab_out.write('\n'.join(vocab_lines))


# When the input is a directory with tokenized documents
featurize_instances.py 文件源码 项目:quoll 作者: LanguageMachines 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def run(self):
    """Featurize a directory of tokenized files, one document per file.

    Files in ``self.in_tokdir().path`` are visited in
    ``sorted(..., key=keyfunc)`` order.  Each is converted with
    ``format_tokdoc`` and becomes one document for a single shared
    featurizer.  The sparse matrix parts are saved with ``numpy.savez`` to
    ``self.out_features().path``.  The vocabulary is written
    newline-separated (UTF-8) to ``self.out_vocabulary().path``.
    """
    features = {'tokens': {
        'n_list': self.ngrams.split(),
        'blackfeats': self.blackfeats.split(),
        'mt': self.minimum_token_frequency,
    }}

    documents = [
        format_tokdoc(self.in_tokdir().path + '/' + name, self.lowercase)
        for name in sorted(listdir(self.in_tokdir().path), key=keyfunc)
    ]

    ft = featurizer.Featurizer(documents, features)
    ft.fit_transform()
    instances, vocabulary = ft.return_instances(['tokens'])

    numpy.savez(self.out_features().path,
                data=instances.data,
                indices=instances.indices,
                indptr=instances.indptr,
                shape=instances.shape)
    with open(self.out_vocabulary().path, 'w', encoding='utf-8') as vocab_out:
        vocab_out.write('\n'.join(vocabulary))


# when the input is a file with frogged documents
run_ga.py 文件源码 项目:quoll 作者: LanguageMachines 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def run(self):
    """Create the initial random populations for the genetic algorithm.

    1. Load the sparse instance matrix saved (as data/indices/indptr/shape)
       in ``self.in_vectors().path`` to learn its number of columns.
    2. Generate ``self.population_size`` random feature vectors.  Save
       their sparse parts to ``self.out_vectorpopulation().path``.
    3. Read the parameter-options file (one line per parameter,
       whitespace-separated options).  Encode each parameter as the option
       indices ``0..k-1``.  Generate a random parameter population and save
       it the same way to ``self.out_parameterpopulation().path``.
    """
    # The ``with`` block closes the .npz archive as soon as its arrays are
    # read; the NpzFile returned by numpy.load otherwise keeps the file open.
    with numpy.load(self.in_vectors().path) as loader:
        instances = sparse.csr_matrix(
            (loader['data'], loader['indices'], loader['indptr']),
            shape=loader['shape'])
    num_dimensions = instances.shape[1]

    # Generate the vector population and store its sparse components.
    random_vectorpopulation = ga_functions.random_vectorpopulation(
        num_dimensions, self.population_size)
    numpy.savez(self.out_vectorpopulation().path,
                data=random_vectorpopulation.data,
                indices=random_vectorpopulation.indices,
                indptr=random_vectorpopulation.indptr,
                shape=random_vectorpopulation.shape)

    # Only the number of options per line matters here.
    with open(self.in_parameter_options().path) as infile:
        lines = infile.read().rstrip().split('\n')
    parameter_options = [list(range(len(line.split()))) for line in lines]

    # Generate the parameter population and store its sparse components.
    random_parameterpopulation = ga_functions.random_parameterpopulation(
        parameter_options, self.population_size)
    numpy.savez(self.out_parameterpopulation().path,
                data=random_parameterpopulation.data,
                indices=random_parameterpopulation.indices,
                indptr=random_parameterpopulation.indptr,
                shape=random_parameterpopulation.shape)



################################################################################
###GA Iterator
################################################################################
npyio.py 文件源码 项目:radar 作者: amoose136 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def savez_compressed(file, *args, **kwds):
    """
    Save several arrays into a single file in compressed ``.npz`` format.

    Arrays passed as keyword arguments are stored under their keyword
    names.  Arrays passed positionally are stored as ``arr_0``, ``arr_1``,
    and so on, in argument order.

    Parameters
    ----------
    file : str or file
        Destination file name or open file object, forwarded unchanged
        to ``_savez``.
    args : Arguments
        Arrays to save under automatically generated names.
    kwds : Keyword arguments
        Arrays to save under the given names.

    See Also
    --------
    numpy.savez : Save several arrays into an uncompressed ``.npz`` file format
    numpy.load : Load the files created by savez_compressed.

    """
    compress = True
    _savez(file, args, kwds, compress)
test_io.py 文件源码 项目:radar 作者: amoose136 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def test_closing_fid(self):
        # Test that issue #1517 (too many opened files) remains closed
        # It might be a "weak" test since failed to get triggered on
        # e.g. Debian sid of 2012 Jul 05 but was reported to
        # trigger the failure on Ubuntu 10.04:
        # http://projects.scipy.org/numpy/ticket/1517#comment:2
        with temppath(suffix='.npz') as tmp:
            np.savez(tmp, data='LOVELY LOAD')
            # We need to check if the garbage collector can properly close
            # numpy npz file returned by np.load when their reference count
            # goes to zero.  Python 3 running in debug mode raises a
            # ResourceWarning when file closing is left to the garbage
            # collector, so we catch the warnings.  Because ResourceWarning
            # is unknown in Python < 3.x, we take the easy way out and
            # catch all warnings.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                for i in range(1, 1025):
                    try:
                        np.load(tmp)["data"]
                    except Exception as e:
                        msg = "Failed to load data from a file: %s" % e
                        raise AssertionError(msg)
test_io.py 文件源码 项目:radar 作者: amoose136 项目源码 文件源码 阅读 35 收藏 0 点赞 0 评论 0
def test_npzfile_dict():
    # The NpzFile returned by np.load must expose exactly the stored array
    # names through membership, keys(), items() and plain iteration.
    buf = BytesIO()
    np.savez(buf, x=np.zeros((3, 3)), y=np.zeros((3, 3)))
    buf.seek(0)

    npz = np.load(buf)
    expected = ['x', 'y']

    for name in expected:
        assert_(name in npz)
    for name in expected:
        assert_(name in npz.keys())

    for name, arr in npz.items():
        assert_(name in expected)
        assert_equal(arr.shape, (3, 3))

    assert_(len(npz.items()) == 2)

    for name in npz:
        assert_(name in expected)

    assert_('x' in npz.keys())
test_io.py 文件源码 项目:radar 作者: amoose136 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def test_load_refcount():
    # np.load must not create reference cycles: what it returns should be
    # freed by reference counting alone, without help from the cyclic GC.
    buf = BytesIO()
    np.savez(buf, [1, 2, 3])
    buf.seek(0)

    assert_(gc.isenabled())
    gc.disable()
    try:
        gc.collect()  # start from a clean slate
        np.load(buf)
        # With automatic GC disabled, any unreachable cyclic objects found
        # now were created by the np.load call above.
        cycles_found = gc.collect()
    finally:
        gc.enable()
    assert_equal(cycles_found, 0)
npz.py 文件源码 项目:cupy 作者: cupy 项目源码 文件源码 阅读 77 收藏 0 点赞 0 评论 0
def savez(file, *args, **kwds):
    """Saves one or more arrays into a file in uncompressed ``.npz`` format.

    Positional arrays are stored under automatic keys ``arr_0``, ``arr_1``,
    etc., following their position in the argument list.  Keyword arrays
    are stored under their keyword names.  These keys are what the
    NpzFile object exposes when the file is read back with
    :func:`cupy.load`.

    Args:
        file (file or str): File or filename to save.
        *args: Arrays with implicit keys.
        **kwds: Arrays with explicit keys.

    .. seealso:: :func:`numpy.savez`

    """
    # Every array goes through cupy.asnumpy before numpy.savez sees it.
    host_kwds = {key: cupy.asnumpy(value) for key, value in kwds.items()}
    host_args = [cupy.asnumpy(arr) for arr in args]
    numpy.savez(file, *host_args, **host_kwds)
rl_tuner.py 文件源码 项目:magenta 作者: tensorflow 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def save_stored_rewards(self, file_name):
    """Saves the model's stored reward histories in a .npz file.

    The file is written to ``self.output_dir`` as
    ``<file_name>-<training_epochs>``, where
    ``training_epochs = len(self.rewards_batched) * self.output_every_nth``.

    Args:
      file_name: Base name of the file that will be saved.
    """
    epochs = len(self.rewards_batched) * self.output_every_nth
    target = os.path.join(self.output_dir, file_name + '-' + str(epochs))
    reward_histories = {
        'train_rewards': self.rewards_batched,
        'train_music_theory_rewards': self.music_theory_rewards_batched,
        'train_note_rnn_rewards': self.note_rnn_rewards_batched,
        'eval_rewards': self.eval_avg_reward,
        'eval_music_theory_rewards': self.eval_avg_music_theory_reward,
        'eval_note_rnn_rewards': self.eval_avg_note_rnn_reward,
        'target_val_list': self.target_val_list,
    }
    np.savez(target, **reward_histories)
mainLoop.py 文件源码 项目:NMT 作者: tuzhaopeng 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def test(self):
        self.model.best_params = [(x.name, x.get_value()) for x in
                                  self.model.params]
        numpy.savez(self.state['prefix'] + '_best_params',
                    **dict(self.model.best_params))
        self.state['best_params_pos'] = self.step
        if self.test_data is not None:
            rvals = self.model.validate(self.test_data)
        else:
            rvals = []
        msg = '>>>         Test'
        pos = self.step // self.state['validFreq']
        for k, v in rvals:
            msg = msg + ' ' + k + ':%6.3f ' % v
            self.timings['test' + k][pos] = float(v)
            self.state['test' + k] = float(v)
        print msg
        self.state['testtime'] = float(time.time()-self.start_time)/60.
mainLoop.py 文件源码 项目:NMT 作者: tuzhaopeng 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def save(self):
    """Persist timings, model parameters and training state to disk.

    Writes ``<prefix>timing.npz`` from ``self.timings``.  Saves the model
    to ``<prefix>model.npz`` when ``state['overwrite']`` is set, otherwise
    to ``<prefix>model<save_iter>.npz``.  Pickles ``self.state`` to
    ``<prefix>state.pkl``, then increments ``self.save_iter``.  SIGINT is
    ignored for the duration of the save.
    """
    start = time.time()
    print("Saving the model...")

    # Ignore keyboard interrupts while saving.  The previous handler is
    # restored in ``finally`` so a failure during saving no longer leaves
    # Ctrl-C permanently disabled.
    previous_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        numpy.savez(self.state['prefix'] + 'timing.npz',
                    **self.timings)
        if self.state['overwrite']:
            self.model.save(self.state['prefix'] + 'model.npz')
        else:
            self.model.save(self.state['prefix'] +
                            'model%d.npz' % self.save_iter)
        # Close the pickle file deterministically instead of relying on GC.
        with open(self.state['prefix'] + 'state.pkl', 'w') as state_file:
            cPickle.dump(self.state, state_file)
        self.save_iter += 1
    finally:
        signal.signal(signal.SIGINT, previous_handler)

    print("Model saved, took {}".format(time.time() - start))

    # FIXME
mainLoop.py 文件源码 项目:NMT 作者: tuzhaopeng 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def test(self):
        self.model.best_params = [(x.name, x.get_value()) for x in
                                  self.model.params]
        numpy.savez(self.state['prefix'] + '_best_params',
                    **dict(self.model.best_params))
        self.state['best_params_pos'] = self.step
        if self.test_data is not None:
            rvals = self.model.validate(self.test_data)
        else:
            rvals = []
        msg = '>>>         Test'
        pos = self.step // self.state['validFreq']
        for k, v in rvals:
            msg = msg + ' ' + k + ':%6.3f ' % v
            self.timings['test' + k][pos] = float(v)
            self.state['test' + k] = float(v)
        print msg
        self.state['testtime'] = float(time.time()-self.start_time)/60.
windeval.py 文件源码 项目:POWER 作者: pennelise 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def save_total_power(data, times, SCADA_faults, filename):
    """Aggregate records per timestamp and save them to ``FormattedData``.

    For every distinct value of ``times`` (sorted, as returned by
    ``np.unique``) this computes:

    * ``total_power`` -- the sum of ``data`` over the records at that time;
    * ``time`` -- the timestamp itself, stored as float64;
    * ``percent_active`` -- the fraction (0..1) of records at that time
      whose ``SCADA_faults`` state is 1 or 2.  NOTE(review): confirm that
      states 1 and 2 are the "active" states the name refers to.

    The arrays are written with ``np.savez`` to
    ``<cwd>/FormattedData/<filename>``.

    Args:
        data: Per-record values aligned with ``times``.
        times: Per-record timestamps.  Must be a NumPy array so that
            ``times == t`` yields a boolean mask.
        SCADA_faults: Per-record SCADA state codes aligned with ``times``.
        filename: Output file name inside ``FormattedData``.
    """
    unique_times = np.unique(times)
    n_times = len(unique_times)
    # Fill preallocated arrays; the previous repeated np.append copied the
    # whole array on every iteration (quadratic in the number of times).
    total_power = np.empty(n_times)
    percent_active = np.empty(n_times)
    for i, t in enumerate(unique_times):
        at_time = times == t  # computed once, reused for both lookups
        state_fault = SCADA_faults[at_time]
        fault_mask = np.logical_or(state_fault == 2, state_fault == 1)
        total_power[i] = np.sum(data[at_time])
        percent_active[i] = float(np.sum(fault_mask)) / float(len(fault_mask))

    total_dictionary = {
        'total_power': total_power,
        # Appending onto an empty float array used to coerce the timestamps
        # to float64; keep that output dtype.
        'time': unique_times.astype(np.float64),
        'percent_active': percent_active,
    }

    file_path = os.path.normpath('%s/FormattedData/%s' % (os.getcwd(), filename))
    np.savez(file_path, **total_dictionary)
write_htk_npz.py 文件源码 项目:recipe_zs2017_track2 作者: kamperh 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def main():
    """Pack every HTK feature file of a directory into one NumPy archive.

    Each ``*.<extension>`` file under ``args.htk_dir`` is dumped with
    ``HList -r``.  Every non-empty output line becomes a row of floats,
    converted with ``np.array``.  The result is stored in ``args.npz_fn``
    under the file's base name without extension.
    """
    args = check_argv()

    print(datetime.datetime.now())

    print("Reading HTK features from directory: %s" % args.htk_dir)
    npz_dict = {}
    n_feat_files = 0
    pattern = path.join(args.htk_dir, "*." + args.extension)
    for feat_fn in glob.glob(pattern):
        hlist_output = shell("HList -r " + feat_fn)
        rows = []
        for line in hlist_output.split("\n"):
            if line == "":
                continue
            rows.append([float(tok) for tok in line.split(" ") if tok != ""])
        utt_key = path.splitext(path.basename(feat_fn))[0]
        npz_dict[utt_key] = np.array(rows)
        n_feat_files += 1
    print("Read %d feature files" % n_feat_files)

    print("Writing Numpy archive: %s" % args.npz_fn)
    np.savez(args.npz_fn, **npz_dict)

    print(datetime.datetime.now())
okko_to_npz.py 文件源码 项目:recipe_zs2017_track2 作者: kamperh 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def main():
    """Convert the training features of a ``.mat`` file into an ``.npz`` file.

    The input is opened with PyTables (presumably a MATLAB v7.3/HDF5 file
    -- confirm).  File names are decoded from ``files_train`` as character
    codes; the matching feature matrices come from ``F_train_iter``.  Each
    matrix is stored transposed under its base file name, with ``_``
    replaced by ``-``.
    """
    args = check_argv()

    print("Reading:", args.mat_fn)
    # The context manager closes the HDF5 handle, which was previously
    # never closed.
    with tables.open_file(args.mat_fn) as mat:
        # Index each node once instead of on every loop iteration.
        files_train = mat.root.files_train[0]
        features_train = mat.root.F_train_iter[0]

        n_audio = files_train.shape[0]
        print("No. audio files:", n_audio)

        filenames = [
            "".join([chr(code[0]) for code in files_train[i_audio][0]])
            for i_audio in range(n_audio)
        ]
        audio_keys = [path.splitext(path.split(fn)[-1])[0] for fn in filenames]

        features_dict = {}
        for i_audio in range(n_audio):
            features = features_train[i_audio][0]
            features_dict[audio_keys[i_audio].replace("_", "-")] = features.T

    print("Writing:", args.npz_fn)
    np.savez(args.npz_fn, **features_dict)
lstm_theanompi_outdated.py 文件源码 项目:Theano-MPI 作者: uoguelph-mlrg 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def cleanup(self, *args, **kwargs):
    """Restore the best parameters, report final errors and optionally save.

    Copies ``self.best_p`` into ``self.tparams``, or captures the current
    parameters when no best snapshot exists.  Disables noise and evaluates
    the prediction error on train/valid/test; rank 0 prints it.  When
    ``model_options['saveto']`` is set, the errors, ``self.history_errs``
    and the best parameters are saved there with ``numpy.savez``.
    """
    from theanompi.models.lstm import zipp, unzip, get_minibatches_idx, pred_error

    if self.best_p is not None:
        zipp(self.best_p, self.tparams)
    else:
        self.best_p = unzip(self.tparams)

    self.use_noise.set_value(0.)

    options = self.model_options
    # The original referenced undefined ``kf_valid``, ``kf_test`` and
    # ``saveto`` (NameError).  Build the evaluation minibatches here, using
    # ``valid_batch_size`` when present, else ``batch_size``.
    # NOTE(review): confirm which batch size validation should use.
    eval_batch_size = options.get('valid_batch_size', options['batch_size'])
    kf_train_sorted = get_minibatches_idx(len(self.train[0]), options['batch_size'])
    kf_valid = get_minibatches_idx(len(self.valid[0]), eval_batch_size)
    kf_test = get_minibatches_idx(len(self.test[0]), eval_batch_size)

    train_err = pred_error(self.f_pred, self.prepare_data, self.train, kf_train_sorted)
    valid_err = pred_error(self.f_pred, self.prepare_data, self.valid, kf_valid)
    test_err = pred_error(self.f_pred, self.prepare_data, self.test, kf_test)

    if self.rank == 0:
        print('Train ', train_err, 'Valid ', valid_err, 'Test ', test_err)

    saveto = options.get('saveto')
    if saveto:
        # NOTE(review): unlike the print above, the save is not gated on
        # rank; if cleanup runs on every rank they all write this file.
        numpy.savez(saveto, train_err=train_err,
                    valid_err=valid_err, test_err=test_err,
                    history_errs=self.history_errs, **self.best_p)
test_hists.py 文件源码 项目:scikit-hep 作者: scikit-hep 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def test_error_bars_stacked2(cmdopt, data_gen):
    """Regression test for a 'barstacked' histogram with error bars.

    ``cmdopt == "generate"`` stores the three returned arrays (keys ``bc``,
    ``be`` and ``berr``) as the reference answers and shows the plot.
    ``cmdopt == "test"`` compares the output exactly against them.
    """
    output = skh_plt.hist([data_gen[0], data_gen[1]], bins=20, histtype='barstacked',
                          weights=[data_gen[2], data_gen[2]], errorbars=True, err_return=True,
                          normed=True, scale=2)

    answer_file = answer_dir + '/answers_error_bars_stacked2.npz'
    if cmdopt == "generate":
        with open(answer_file, 'wb') as f:
            np.savez(f, bc=output[0], be=output[1], berr=output[2])
        plt.title('test_error_bars_stacked2')
        plt.show()
    elif cmdopt == "test":
        # Close the reference archive instead of leaking its file handle.
        with np.load(answer_file) as answers:
            assert np.all(output[0] == answers['bc'])
            assert np.all(output[1] == answers['be'])
            assert np.all(output[2] == answers['berr'])
test_hists.py 文件源码 项目:scikit-hep 作者: scikit-hep 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def test_error_bars_stacked3(cmdopt, data_gen):
    """Regression test for a stacked 'step' histogram with error bars.

    ``cmdopt == "generate"`` stores the three returned arrays (keys ``bc``,
    ``be`` and ``berr``) as the reference answers and shows the plot.
    ``cmdopt == "test"`` compares the output exactly against them.
    """
    output = skh_plt.hist([data_gen[0], data_gen[1]], bins=20, histtype='step', stacked=True,
                          weights=[data_gen[2], data_gen[2]], errorbars=True, err_return=True,
                          normed=True, scale=2)

    answer_file = answer_dir + '/answers_error_bars_stacked3.npz'
    if cmdopt == "generate":
        with open(answer_file, 'wb') as f:
            np.savez(f, bc=output[0], be=output[1], berr=output[2])
        # Title previously said 'stacked2' (copy-paste slip).
        plt.title('test_error_bars_stacked3')
        plt.show()
    elif cmdopt == "test":
        # Close the reference archive instead of leaking its file handle.
        with np.load(answer_file) as answers:
            assert np.all(output[0] == answers['bc'])
            assert np.all(output[1] == answers['be'])
            assert np.all(output[2] == answers['berr'])
test_hists.py 文件源码 项目:scikit-hep 作者: scikit-hep 项目源码 文件源码 阅读 36 收藏 0 点赞 0 评论 0
def test_ratio_plot(cmdopt, data_gen):
    """Regression test for ``ratio_plot`` with linear axes.

    ``cmdopt == "generate"`` stores the first two arrays of output items 1
    and 2 (keys ``bc1``/``be1``/``bc2``/``be2``) and shows the plot.
    ``cmdopt == "test"`` compares them exactly against the stored answers.
    """
    output = skh_plt.ratio_plot(dict(x=data_gen[0], errorbars=True, histtype='marker'),
                                dict(x=data_gen[1], weights=data_gen[2], errorbars=True))

    answer_file = answer_dir + '/answers_ratio_plot.npz'
    if cmdopt == "generate":
        with open(answer_file, 'wb') as f:
            np.savez(f, bc1=output[1][0], be1=output[1][1],
                     bc2=output[2][0], be2=output[2][1])
        output[0][0].set_title('test_ratio_plot')
        plt.show()
    elif cmdopt == "test":
        # Close the reference archive instead of leaking its file handle.
        with np.load(answer_file) as answers:
            assert np.all(output[1][0] == answers['bc1'])
            assert np.all(output[1][1] == answers['be1'])
            assert np.all(output[2][0] == answers['bc2'])
            assert np.all(output[2][1] == answers['be2'])
test_hists.py 文件源码 项目:scikit-hep 作者: scikit-hep 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def test_ratio_plot_log(cmdopt, data_gen):
    """Regression test for ``ratio_plot`` with log axes and ``ratio_range``.

    ``cmdopt == "generate"`` stores the first two arrays of output items 1
    and 2 (keys ``bc1``/``be1``/``bc2``/``be2``) and shows the plot.
    ``cmdopt == "test"`` compares them exactly against the stored answers.
    """
    output = skh_plt.ratio_plot(dict(x=data_gen[0], errorbars=True, histtype='marker', log=True,
                                     err_x=False),
                                dict(x=data_gen[1], weights=data_gen[2], errorbars=True),
                                logx=True, ratio_range=(0, 10))

    answer_file = answer_dir + '/answers_ratio_plot_log.npz'
    if cmdopt == "generate":
        with open(answer_file, 'wb') as f:
            np.savez(f, bc1=output[1][0], be1=output[1][1],
                     bc2=output[2][0], be2=output[2][1])
        output[0][0].set_title('test_ratio_plot_log')
        plt.show()
    elif cmdopt == "test":
        # Close the reference archive instead of leaking its file handle.
        with np.load(answer_file) as answers:
            assert np.all(output[1][0] == answers['bc1'])
            assert np.all(output[1][1] == answers['be1'])
            assert np.all(output[2][0] == answers['bc2'])
            assert np.all(output[2][1] == answers['be2'])
io.py 文件源码 项目:fg21sim 作者: liweitianux 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def write_dndlnm(outfile, dndlnm, z, mass, clobber=False):
    """
    Save the halo mass distribution arrays into a NumPy ``.npz`` archive.

    The archive holds three arrays under the keys ``dndlnm``, ``z`` and
    ``mass``.

    Parameters
    ----------
    outfile : str
        Path of the ``.npz`` output file.  Its directory is created first
        via ``_create_dir``.  NOTE(review): ``np.savez`` appends ``.npz``
        when the name lacks it, while the existence check uses the name as
        given -- confirm callers always pass the ``.npz`` suffix.
    dndlnm : 2D float `~numpy.ndarray`
        Shape: (len(z), len(mass))
        Differential mass function in terms of natural log of M.
        Unit: [Mpc^-3] (the little "h" is folded into the values)
    z : 1D float `~numpy.ndarray`
        Redshifts at which the mass distribution was computed.
    mass : 1D float `~numpy.ndarray`
        (Logarithmically distributed) mass points.
        Unit: [Msun] (the little "h" is folded into the values)
    clobber : bool, optional
        Whether an existing output file may be overwritten; passed to
        ``_check_existence`` together with ``remove=True``.
    """
    _create_dir(outfile)
    _check_existence(outfile, clobber=clobber, remove=True)
    arrays = {"dndlnm": dndlnm, "z": z, "mass": mass}
    np.savez(outfile, **arrays)
write_rollout_data.py 文件源码 项目:gym 作者: openai 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def main():
    """Record two random-agent rollouts of a Gym environment into ``.npz``.

    Command line: ``envid outfile [--gymdir DIR]``.  When ``--gymdir`` is
    given it is put first on ``sys.path`` before ``gym`` is imported.
    Rollouts use seeds 0 and 1.  Each rollout key ``k`` is stored as
    ``"<seed>-<k>"``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("envid")
    parser.add_argument("outfile")
    parser.add_argument("--gymdir")
    args = parser.parse_args()

    if args.gymdir:
        sys.path.insert(0, args.gymdir)
    # Imported only after sys.path has been adjusted.
    import gym
    from gym import utils
    print(utils.colorize("gym directory: %s" % path.dirname(gym.__file__), "yellow"))

    env = gym.make(args.envid)
    agent = RandomAgent(env.action_space)
    alldata = {}
    for seed in range(2):
        np.random.seed(seed)
        rollout_data = rollout(env, agent, env.spec.max_episode_steps)
        alldata.update(("%i-%s" % (seed, key), value)
                       for key, value in rollout_data.items())
    np.savez(args.outfile, **alldata)
npyio.py 文件源码 项目:krpcScripts 作者: jwvanderbeck 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def savez_compressed(file, *args, **kwds):
    """
    Save several arrays into a single file in compressed ``.npz`` format.

    Arrays passed as keyword arguments are stored under their keyword
    names.  Arrays passed positionally are stored as ``arr_0``, ``arr_1``,
    and so on, in argument order.

    Parameters
    ----------
    file : str or file
        Destination file name or open file object, forwarded unchanged
        to ``_savez``.
    args : Arguments
        Arrays to save under automatically generated names.
    kwds : Keyword arguments
        Arrays to save under the given names.

    See Also
    --------
    numpy.savez : Save several arrays into an uncompressed ``.npz`` file format
    numpy.load : Load the files created by savez_compressed.

    """
    compress = True
    _savez(file, args, kwds, compress)
test_io.py 文件源码 项目:krpcScripts 作者: jwvanderbeck 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def test_closing_fid(self):
        # Test that issue #1517 (too many opened files) remains closed
        # It might be a "weak" test since failed to get triggered on
        # e.g. Debian sid of 2012 Jul 05 but was reported to
        # trigger the failure on Ubuntu 10.04:
        # http://projects.scipy.org/numpy/ticket/1517#comment:2
        with temppath(suffix='.npz') as tmp:
            np.savez(tmp, data='LOVELY LOAD')
            # We need to check if the garbage collector can properly close
            # numpy npz file returned by np.load when their reference count
            # goes to zero.  Python 3 running in debug mode raises a
            # ResourceWarning when file closing is left to the garbage
            # collector, so we catch the warnings.  Because ResourceWarning
            # is unknown in Python < 3.x, we take the easy way out and
            # catch all warnings.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                for i in range(1, 1025):
                    try:
                        np.load(tmp)["data"]
                    except Exception as e:
                        msg = "Failed to load data from a file: %s" % e
                        raise AssertionError(msg)
test_io.py 文件源码 项目:krpcScripts 作者: jwvanderbeck 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def test_npzfile_dict():
    # The NpzFile returned by np.load must expose exactly the stored array
    # names through membership, keys(), items() and plain iteration.
    buf = BytesIO()
    np.savez(buf, x=np.zeros((3, 3)), y=np.zeros((3, 3)))
    buf.seek(0)

    npz = np.load(buf)
    expected = ['x', 'y']

    for name in expected:
        assert_(name in npz)
    for name in expected:
        assert_(name in npz.keys())

    for name, arr in npz.items():
        assert_(name in expected)
        assert_equal(arr.shape, (3, 3))

    assert_(len(npz.items()) == 2)

    for name in npz:
        assert_(name in expected)

    assert_('x' in npz.keys())
test_io.py 文件源码 项目:krpcScripts 作者: jwvanderbeck 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def test_load_refcount():
    # np.load must not create reference cycles: what it returns should be
    # freed by reference counting alone, without help from the cyclic GC.
    buf = BytesIO()
    np.savez(buf, [1, 2, 3])
    buf.seek(0)

    assert_(gc.isenabled())
    gc.disable()
    try:
        gc.collect()  # start from a clean slate
        np.load(buf)
        # With automatic GC disabled, any unreachable cyclic objects found
        # now were created by the np.load call above.
        cycles_found = gc.collect()
    finally:
        gc.enable()
    assert_equal(cycles_found, 0)
convnade.py 文件源码 项目:NADE 作者: MarcCote 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def save(self, path):
    """Serialize this model's metadata, hyperparameters and parameters.

    All files go into ``<path>/<ClassName>/``, created through
    ``smartutils.create_folder``:

    * ``meta.json`` -- ``{"name": <class name>}``;
    * ``hyperparams.json`` -- the constructor settings, tagged ``version: 2``;
    * ``params.npz`` -- ``self.parameters`` saved positionally, so the keys
      are ``arr_0``, ``arr_1``, ... in parameter order.
    """
    savedir = smartutils.create_folder(pjoin(path, type(self).__name__))

    hyperparameters = {
        'version': 2,
        'image_shape': self.image_shape,
        'nb_channels': self.nb_channels,
        'ordering_seed': self.ordering_seed,
        'use_mask_as_input': self.use_mask_as_input,
        'hidden_activation': self.hidden_activation,
        'has_convnet': self.has_convnet,
        'has_fullnet': self.has_fullnet,
    }
    meta = {"name": self.__class__.__name__}
    smartutils.save_dict_to_json_file(pjoin(savedir, "meta.json"), meta)
    smartutils.save_dict_to_json_file(pjoin(savedir, "hyperparams.json"), hyperparameters)

    # Save all model parameters positionally.  NOTE(review): if these are
    # symbolic shared variables rather than arrays, np.savez would pickle
    # them as objects instead of storing plain arrays -- confirm.
    np.savez(pjoin(savedir, "params.npz"), *self.parameters)
batch_schedulers.py 文件源码 项目:NADE 作者: MarcCote 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def __len__(self):
    """Report ``self.D`` as the length of this object."""
    size = self.D
    return size

    # def save(self, savedir):
    #     state = {"version": 1,
    #              "seed": self.seed,
    #              "use_mask_as_input": self.use_mask_as_input,
    #              "batch_size": self.batch_size,
    #              "shared_batch_count": self.shared_batch_count.get_value(),
    #              "rng": pickle.dumps(self.rng),
    #              "shared_batch_mask": self._shared_mask_o_lt_d.get_value(),
    #              }

    #     np.savez(pjoin(savedir, 'mini_batch_scheduler_with_autoregressive_mask.npz'), **state)

    # def load(self, loaddir):
    #     state = np.load(pjoin(loaddir, 'mini_batch_scheduler_with_autoregressive_mask.npz'))
    #     self.batch_size = state["batch_size"]
    #     self.shared_batch_count.set_value(state["shared_batch_count"])
    #     self.rng = pickle.loads(state["rng"])
    #     self._shared_mask_o_lt_d.set_value(state["shared_batch_mask"])


问题


面经


文章

微信
公众号

扫码关注公众号