Python numpy.testing.assert_array_equal() 函数的实例源码

test_pastream.py 文件源码 项目:pastream 作者: tgarc 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def assert_soundfiles_equal(inp_fh, out_fh, preamble, dtype):
    """Assert that the recording in *out_fh* reproduces *inp_fh*.

    The output file is first aligned by locating *preamble* in it with
    ``find_soundfile_delay`` (which signals "not found" with -1). The rest of
    the output is then compared block by block with the input. Both sides are
    viewed as the unsigned counterpart of *dtype*, so the comparison is
    bitwise.

    Prints a one-line summary: frames matched, total input frames, initial
    delay, and the input frames left unread after the comparison.
    """
    delay = find_soundfile_delay(out_fh, preamble, dtype)
    assert delay != -1, "Test Preamble pattern not found"
    out_fh.seek(delay)

    chunk = 2048
    # 'int16' -> 'uint16'; an already-unsigned name such as 'uint16' maps to itself.
    bitwise_dtype = 'u' + dtype.lstrip('u')
    scratch = np.zeros((chunk, inp_fh.channels), dtype=dtype)
    matched = 0
    for out_block in out_fh.blocks(chunk, dtype=dtype, always_2d=True):
        # Read exactly as many input frames as this output block holds.
        nread = inp_fh.buffer_read_into(scratch[:len(out_block)], dtype=dtype)
        npt.assert_array_equal(scratch[:nread].view(bitwise_dtype),
                               out_block.view(bitwise_dtype),
                               "Loopback data mismatch")
        matched += nread

    total = len(inp_fh)
    print("Matched %d of %d frames; Initial delay of %d frames; %d frames truncated"
          % (matched, total, delay, total - inp_fh.tell()))
mpi_sens_test.py 文件源码 项目:F_UNCLE 作者: fraserphysics 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def test_multi_solve_pll(models, sim):
    """Check ``sim.multi_solve_mpi`` over 20 shifted copies of the 'simp' DOFs.

    DOF set ``i`` is the base DOF vector of ``models['simp']`` offset by
    ``i``. Result ``i`` is expected to satisfy
    ``res[1][0] == arange(10) * (arange(10) + i + 1)``.
    """
    # The base DOF vector is loop-invariant; fetch it once.
    base_dof = models['simp'].get_dof()
    dofs = [base_dof + i for i in range(20)]

    results = sim.multi_solve_mpi(models, ['simp'], dofs, verb=False)

    for i, res in enumerate(results):
        npt.assert_array_equal(
            res[1][0],
            np.arange(10) * (np.arange(10) + i + 1),
            err_msg='Error in {:d} dof set'.format(i),
        )
test_hwbi_internal.py 文件源码 项目:hwbi_app 作者: quanted 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_cyan_200():
    """Request every URL in ``check_pages`` and record whether all returned 200.

    Failures are not raised. They are recorded through
    ``linkcheck_helper.write_report``, which always runs, even when a
    request itself fails.
    """
    test_name = "Check page access "
    assert_error = False
    # Initialised before the ``try`` so the ``finally`` clause can always
    # report. Previously a failing request left ``response`` unbound and the
    # report call raised NameError instead.
    response = []
    try:
        response = [requests.get(p).status_code for p in check_pages]
        try:
            npt.assert_array_equal(response, 200, '200 error', True)
        except AssertionError:
            assert_error = True
    except Exception as e:
        # A page that cannot be fetched at all is also an access failure.
        assert_error = True
        # ``e.message`` does not exist on Python 3 exceptions; format ``e`` itself.
        print("Error '{0}' occured. Arguments {1}.".format(e, e.args))
    finally:
        linkcheck_helper.write_report(test_name, assert_error, check_pages, response)
test_hwbi_internal.py 文件源码 项目:hwbi_app 作者: quanted 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def test_cyan_api_endpoints_200():
    """Request every URL in ``api_endpoints`` and record whether all returned 200.

    Failures are not raised. They are recorded through
    ``linkcheck_helper.write_report``, which always runs, even when a
    request itself fails.
    """
    test_name = "Check page access "
    assert_error = False
    # Initialised before the ``try`` so the ``finally`` clause can always
    # report. Previously a failing request left ``response`` unbound and the
    # report call raised NameError instead.
    response = []
    try:
        response = [requests.get(p).status_code for p in api_endpoints]
        try:
            npt.assert_array_equal(response, 200, '200 error', True)
        except AssertionError:
            assert_error = True
    except Exception as e:
        # An endpoint that cannot be fetched at all is also an access failure.
        assert_error = True
        # ``e.message`` does not exist on Python 3 exceptions; format ``e`` itself.
        print("Error '{0}' occured. Arguments {1}.".format(e, e.args))
    finally:
        linkcheck_helper.write_report(test_name, assert_error, api_endpoints, response)

# For each method whose name starts with "test", unittest
# 1) calls the setUp method,
# 2) then calls that test method,
# 3) then calls the tearDown method.
test_mrcz.py 文件源码 项目:python-mrcz 作者: em-MRCZ 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def compReadWrite(self, testMage, casttype=None, compressor=None, clevel=1):
        """Write *testMage* to a temporary MRC file, read it back, and compare.

        Checks that the image data and the header fields written (voltage,
        pixelsize, pixelunits, C3, gain) survive the round trip.

        Parameters
        ----------
        testMage : array
            Image data to write.
        casttype : optional
            Passed to ``mrcz.writeMRC`` as ``dtype``.
        compressor, clevel : optional
            Compression codec and level passed to ``mrcz.writeMRC``.
        """
        mrcName = os.path.join(tmpDir, "testMage.mrc")
        pixelsize = np.array([1.2, 2.6, 3.4])
        # Backslash followed by "AA" (presumably the LaTeX Angstrom macro).
        # Spelled "\\AA" because "\AA" relies on an invalid escape sequence,
        # which is a SyntaxWarning on Python 3.12+. The string value is unchanged.
        pixelunits = u"\\AA"

        mrcz.writeMRC(testMage, mrcName, dtype=casttype,
                      pixelsize=pixelsize, pixelunits=pixelunits,
                      voltage=300.0, C3=2.7, gain=1.05,
                      compressor=compressor, clevel=clevel)

        rereadMage, rereadHeader = mrcz.readMRC(mrcName, pixelunits=pixelunits)
        try:
            os.remove(mrcName)
        except OSError:
            # os.remove raises OSError; IOError is only an alias of it on Python 3.
            log.info("Warning: file {} left on disk".format(mrcName))

        npt.assert_array_almost_equal(testMage, rereadMage)
        npt.assert_array_equal(rereadHeader['voltage'], 300.0)
        npt.assert_array_almost_equal(rereadHeader['pixelsize'], pixelsize)
        npt.assert_array_equal(rereadHeader['pixelunits'], pixelunits)
        npt.assert_array_equal(rereadHeader['C3'], 2.7)
        npt.assert_array_equal(rereadHeader['gain'], 1.05)
test_annos.py 文件源码 项目:deepcpg 作者: cangermueller 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def test_join_overlapping():
    """``annos.join_overlapping`` merges overlapping or touching intervals."""
    f = annos.join_overlapping

    s, e = f([], [])
    assert len(s) == 0
    assert len(e) == 0

    # Already-disjoint intervals come back unchanged.
    s = [1, 3, 6]
    e = [2, 4, 10]
    # Compare element-wise. ``result == expect`` on a tuple whose items are
    # NumPy arrays raises "truth value of an array is ambiguous".
    npt.assert_array_equal(f(s, e), (s, e))

    x = np.array([[1, 2],
                  [3, 4], [4, 5],
                  [6, 8], [8, 8], [8, 9],
                  [10, 15], [10, 11], [11, 14], [14, 16]])
    expect = [[1, 2], [3, 5], [6, 9], [10, 16]]
    result = np.array(f(x[:, 0], x[:, 1])).T
    npt.assert_array_equal(result, expect)
test_annos.py 文件源码 项目:deepcpg 作者: cangermueller 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def test_in_which():
    """``annos.in_which`` gives the index of the containing interval, else -1.

    Interval bounds are inclusive on both ends (e.g. 2 lies in [2, 2]).
    """
    in_which = annos.in_which
    starts = [2, 4, 12, 17]
    ends = [2, 8, 15, 18]

    cases = [
        ([], []),
        ([-1, 3, 9, 19], [-1, -1, -1, -1]),
        ([-1, 2, 2, 3, 4, 8, 15, 16], [-1, 0, 0, -1, 1, 1, 2, -1]),
    ]
    for positions, expected in cases:
        npt.assert_array_equal(in_which(positions, starts, ends), expected)
test_annos.py 文件源码 项目:deepcpg 作者: cangermueller 项目源码 文件源码 阅读 16 收藏 0 点赞 0 评论 0
def test_distance():
    """``annos.distance`` gives each position's gap to the nearest interval (0 inside)."""
    start = np.asarray([3, 10, 17])
    end = np.asarray([6, 15, 18])
    cases = (
        ([1, 2, 5, 8, 10, 15, 16, 19], [2, 1, 0, 2, 0, 0, 1, 1]),
        ([1, 6, 7, 9], [2, 0, 1, 1]),
    )
    for positions, expected in cases:
        actual = annos.distance(np.asarray(positions), start, end)
        npt.assert_array_equal(actual, expected)
test_feature_extractor.py 文件源码 项目:deepcpg 作者: cangermueller 项目源码 文件源码 阅读 15 收藏 0 点赞 0 评论 0
def test_join_intervals(self):
        """``join_intervals`` merges overlapping or touching intervals."""
        f = fe.IntervalFeatureExtractor.join_intervals

        s, e = f([], [])
        assert len(s) == 0
        assert len(e) == 0

        # Already-disjoint intervals come back unchanged.
        s = [1, 3, 6]
        e = [2, 4, 10]
        # Compare element-wise. ``result == expect`` on a tuple whose items
        # are NumPy arrays raises "truth value of an array is ambiguous".
        npt.assert_array_equal(f(s, e), (s, e))

        x = np.array([[1, 2],
                      [3, 4], [4, 5],
                      [6, 8], [8, 8], [8, 9],
                      [10, 15], [10, 11], [11, 14], [14, 16]])
        expect = [[1, 2], [3, 5], [6, 9], [10, 16]]
        result = np.array(f(x[:, 0], x[:, 1])).T
        npt.assert_array_equal(result, expect)
test_feature_extractor.py 文件源码 项目:deepcpg 作者: cangermueller 项目源码 文件源码 阅读 16 收藏 0 点赞 0 评论 0
def test_index_intervals(self):
        """``index_intervals`` gives the index of the containing interval, else -1.

        Interval bounds are inclusive on both ends (e.g. 2 lies in [2, 2]).
        """
        index_of = fe.IntervalFeatureExtractor.index_intervals
        starts = [2, 4, 12, 17]
        ends = [2, 8, 15, 18]

        cases = [
            ([], []),
            ([-1, 3, 9, 19], [-1, -1, -1, -1]),
            ([-1, 2, 2, 3, 4, 8, 15, 16], [-1, 0, 0, -1, 1, 1, 2, -1]),
        ]
        for positions, expected in cases:
            npt.assert_array_equal(index_of(positions, starts, ends), expected)
test_feature_extractor.py 文件源码 项目:deepcpg 作者: cangermueller 项目源码 文件源码 阅读 14 收藏 0 点赞 0 评论 0
def test_k1(self):
        """The 1-mer extractor matches hand-counted nucleotide compositions.

        Expected rows are built with ``self._freq`` from per-base counts.
        """
        extract = fe.KmersFeatureExtractor(1)
        cases = [
            ('AGGTTCCC',
             [{'A': 1, 'G': 2, 'T': 2, 'C': 3}]),
            ('AGTGGGTTCCC',
             [{'A': 1, 'G': 4, 'T': 3, 'C': 3}]),
            (['AGTGGGTTCCC', 'GGGGGGGGGGG'],
             [{'A': 1, 'G': 4, 'T': 3, 'C': 3}, {'G': 11}]),
        ]
        for raw, per_seq_counts in cases:
            seqs = self._translate_seqs(raw)
            expect = np.array([self._freq(counts) for counts in per_seq_counts])
            npt.assert_array_equal(extract(seqs), expect)
test_autodiff_cpu.py 文件源码 项目:Aurora 作者: upul 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def test_reshape():
    """Reshape (2, 2) -> (1, 4): check the output shape and the all-ones input gradient."""
    inp = ad.Variable(name='x2')
    out = ad.reshape(inp, newshape=(1, 4))

    (grad_inp,) = ad.gradients(out, [inp])
    runner = ad.Executor([out, grad_inp])
    inp_val = np.random.randn(2, 2)
    out_val, grad_inp_val = runner.run(feed_shapes={inp: inp_val})

    assert isinstance(out, ad.Node)
    assert out_val.shape == (1, 4)
    npt.assert_array_equal(grad_inp_val, np.ones((2, 2)))

    # TODO: the original source carried a disabled 4-D case that reshapes a
    # (2, 6) input to (2, 1, 2, 3) and expects a gradient of
    # np.ones((2, 1, 2, 3)).
    # NOTE(review): that expected gradient shape differs from the (2, 6)
    # input shape. Confirm the intended expectation before re-enabling it.
fileio_tests.py 文件源码 项目:GulpIO 作者: TwentyBN 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def test_read_frames_fixed_length(self):
        """Round-trip one 1x4 uint8 frame through ``_write_frame``/``read_frames``.

        OpenCV encode, decode and colour conversion are patched out. The bytes
        written are the raw array, and reading them back is a reshape.
        """
        chunk = self.gulp_chunk
        chunk.meta_dict = OrderedDict()
        chunk.fp = BytesIO()
        image = np.ones((1, 4), dtype='uint8')

        # Stub cv2.imencode to return (status, buffer) with the raw frame as buffer.
        with mock.patch('cv2.imencode') as encode_stub:
            encode_stub.return_value = '', np.ones((1, 4), dtype='uint8')
            chunk._write_frame(0, image)
        # Attach an empty meta-data record to entry '0'.
        chunk.meta_dict['0']['meta_data'].append({})

        def decode_stub(x, y):
            return np.array(x).reshape((1, 4))

        def identity_stub(x, y):
            return x

        with mock.patch('cv2.imdecode', decode_stub), \
                mock.patch('cv2.cvtColor', identity_stub):
            # Recover the single frame using 'read_frames'.
            frames, meta = chunk.read_frames('0')

        npt.assert_array_equal(image, np.array(frames[0]))
        self.assertEqual({}, meta)
test_auxiliary_methods.py 文件源码 项目:LabelsManager 作者: SebastianoF 项目源码 文件源码 阅读 16 收藏 0 点赞 0 评论 0
def test_set_new_data_simple_modifications():
    """``set_new_data`` swaps the data but keeps the (modified) header and the affine."""
    aff = np.eye(4)
    aff[2, 1] = 42.0

    im_0 = nib.Nifti1Image(np.zeros([3, 3, 3]), affine=aff)
    im_0_header = im_0.header
    # default intent_code
    assert_equals(im_0_header['intent_code'], 0)
    # change the intent code on the source header
    im_0_header['intent_code'] = 5

    # generate a new image from the old one with new data
    im_1 = set_new_data(im_0, np.ones([3, 3, 3]))
    im_1_header = im_1.header
    # The new image must carry the new data, the modified header and the
    # original affine. get_fdata() and .affine replace get_data() and
    # get_affine(), which nibabel deprecated and later removed.
    assert_array_equal(im_1.get_fdata(), np.ones([3, 3, 3]))
    assert_equals(im_1_header['intent_code'], 5)
    assert_array_equal(im_1.affine, aff)
test_auxiliary_methods.py 文件源码 项目:LabelsManager 作者: SebastianoF 项目源码 文件源码 阅读 15 收藏 0 点赞 0 评论 0
def test_binarise_a_matrix():
    """``binarise_a_matrix`` maps 0 -> 0 and 1..4 -> 1 for this input."""
    in_data = np.array([0, 1, 2, 3, 4])
    expected_out_data = np.array([0, 1, 1, 1, 1])
    # ``np.int`` was only an alias of the builtin ``int``; NumPy removed it in
    # 1.24 (AttributeError), so pass ``int`` directly.
    assert_array_equal(expected_out_data, binarise_a_matrix(in_data, dtype=int))


# def test_get_values_below_label():
#
#     image = np.array(range(8 * 8)).reshape(8, 8)
#     mask = np.zeros_like(image)
#     mask[2, 2] = 1
#     mask[2, 3] = 1
#     mask[3, 2] = 1
#     mask[3, 3] = 1
#     vals = get_values_below_label(image, mask, 1)
#     assert_array_equal([image[2, 2], image[2, 3], image[3, 2], image[3, 3]], vals)
data_sources_test.py 文件源码 项目:speech_ml 作者: coopie 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_waveform_data_source(self):
        """WaveformDataSource yields processed waveforms looked up by file stem."""
        ds = WaveformDataSource(FileDataSource(DUMMY_DATA_PATH, suffix='.wav'), process_waveform=dummy_process_waveforms)

        self.assertTrue(
            np.all(
                ds['1_sad_kid_1'] == np.array(2)
            )
        )

        wav_files = [f for f in os.listdir(DUMMY_DATA_PATH) if f.endswith('.wav')]
        paths = [os.path.join(DUMMY_DATA_PATH, f) for f in wav_files]
        # Strip only the final extension. ``split('.')[0]`` would truncate a
        # stem that contains dots (e.g. "a.b.wav" -> "a" instead of "a.b").
        filenames = [os.path.splitext(f)[0] for f in wav_files]

        npt.assert_array_equal(
            np.array([ds[f] for f in filenames]),
            np.array([dummy_process_waveforms(p)[1] for p in paths])
        )
data_sources_test.py 文件源码 项目:speech_ml 作者: coopie 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def test_spectrogram_data_source(self):
        """SpectrogramDataSource over a WaveformDataSource yields processed spectrograms by file stem."""
        ds = \
            SpectrogramDataSource(
                WaveformDataSource(
                    FileDataSource(DUMMY_DATA_PATH, suffix='.wav'),
                    process_waveform=dummy_process_waveforms),
                dummy_process_spectrograms
            )

        self.assertTrue(
            np.all(
                ds['1_sad_kid_1'] == np.eye(2) * 2
            )
        )

        wav_files = [f for f in os.listdir(DUMMY_DATA_PATH) if f.endswith('.wav')]
        paths = [os.path.join(DUMMY_DATA_PATH, f) for f in wav_files]
        # Strip only the final extension. ``split('.')[0]`` would truncate a
        # stem that contains dots (e.g. "a.b.wav" -> "a" instead of "a.b").
        filenames = [os.path.splitext(f)[0] for f in wav_files]

        npt.assert_array_equal(
            np.array([ds[f] for f in filenames]),
            np.array([dummy_process_spectrograms(dummy_process_waveforms(p)[1])[-1] for p in paths])
        )
test_io.py 文件源码 项目:python-smeftrunner 作者: DsixTools 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def test_load(self):
        """Round-trip packaged LHA test data through the io helpers and ``SMEFT.load_initial``."""
        sm = pkgutil.get_data('smeftrunner', 'tests/data/SMInput-CPV.dat').decode('utf-8')
        wc = pkgutil.get_data('smeftrunner', 'tests/data/WCsInput-CPV-SMEFT.dat').decode('utf-8')
        wcout = pkgutil.get_data('smeftrunner', 'tests/data/Output_SMEFTrunner.dat').decode('utf-8')
        # Smoke-test parsing of the input files; the results are not inspected.
        io.sm_lha2dict(pylha.load(sm))
        io.wc_lha2dict(pylha.load(wc))
        CSM = io.sm_lha2dict(pylha.load(wcout))
        C = io.wc_lha2dict(pylha.load(wcout))
        C2 = io.wc_lha2dict(io.wc_dict2lha(C))
        for k in C:
            npt.assert_array_equal(C[k], C2[k], err_msg="Failed for {}".format(k))
        smeft = SMEFT()
        smeft.load_initial((wcout,))
        # Symmetrize once per dict instead of once per key; both are loop-invariant.
        C_sym = definitions.symmetrize(C)
        for k in C:
            npt.assert_array_equal(C_sym[k], smeft.C_in[k], err_msg="Failed for {}".format(k))
        CSM_sym = definitions.symmetrize(CSM)
        for k in CSM:
            npt.assert_array_equal(CSM_sym[k], smeft.C_in[k], err_msg="Failed for {}".format(k))
        CSM2 = io.sm_lha2dict(io.sm_dict2lha(CSM))
        for k in CSM:
            npt.assert_array_equal(CSM[k], CSM2[k], err_msg="Failed for {}".format(k))
test_predict.py 文件源码 项目:cesium_web 作者: cesium-ml 项目源码 文件源码 阅读 16 收藏 0 点赞 0 评论 0
def test_download_prediction_csv_class_prob(driver, project, dataset,
                                            featureset, model, prediction):
    """Download the prediction CSV and check names, labels and class probabilities.

    The downloaded file is always removed afterwards if it was created.
    """
    csv_path = '/tmp/cesium_prediction_results.csv'
    driver.get('/')
    _click_download(project.id, driver)
    assert os.path.exists(csv_path)
    try:
        result = pd.read_csv(csv_path)
        npt.assert_array_equal(result.ts_name, np.arange(5))
        expected_labels = ['Mira', 'Classical_Cepheid',
                           'Mira', 'Classical_Cepheid',
                           'Mira']
        npt.assert_array_equal(result.label, expected_labels)
        probs = result[['Classical_Cepheid', 'Mira']].values
        # Columns are [Classical_Cepheid, Mira], so argmax 1 means 'Mira',
        # which matches expected_labels above.
        npt.assert_array_equal(np.argmax(probs, axis=1), [1, 0, 1, 0, 1])
        assert (probs >= 0.0).all()
    finally:
        os.remove(csv_path)
test_TextAdapter.py 文件源码 项目:TextAdapter 作者: ContinuumIO 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def test_string_parsing(self):
        """'S5' fields parse identically from str, unicode and bytes sources."""
        expected = np.array([('1', '2', '3')], dtype='S5,S5,S5')
        sources = [
            StringIO('1,2,3\n'),
            io.StringIO(u'1,2,3\n'),
            io.BytesIO(b'1,2,3\n'),
        ]
        for source in sources:
            adapter = textadapter.text_adapter(source, field_names=False)
            adapter.set_field_types({0: 'S5', 1: 'S5', 2: 'S5'})
            assert_array_equal(adapter[:], expected)

    # basic utf_8 tests
test_TextAdapter.py 文件源码 项目:TextAdapter 作者: ContinuumIO 项目源码 文件源码 阅读 15 收藏 0 点赞 0 评论 0
def test_no_whitespace_stripping(self):
        """Leading and trailing spaces and tabs are kept verbatim in string fields."""
        cases = [
            ('1  ,2  ,3  \n', 'S3', ('1  ', '2  ', '3  '), 'S3,S3,S3'),
            ('  1,  2,  3\n', 'S3', ('  1', '  2', '  3'), 'S3,S3,S3'),
            ('  1  ,  2  ,  3  \n', 'S5', ('  1  ', '  2  ', '  3  '), 'S5,S5,S5'),
            ('\t1\t,\t2\t,\t3\t\n', 'S3', ('\t1\t', '\t2\t', '\t3\t'), 'S3,S3,S3'),
        ]
        for text, field_type, fields, record_dtype in cases:
            adapter = textadapter.text_adapter(StringIO(text), field_names=False)
            adapter.set_field_types({0: field_type, 1: field_type, 2: field_type})
            assert_array_equal(adapter[:], np.array([fields], dtype=record_dtype))
test_TextAdapter.py 文件源码 项目:TextAdapter 作者: ContinuumIO 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_header_footer(self):
        """``header=N`` skips N leading lines; with field_names the next line names the fields."""
        data = StringIO('0,1,2,3,4\n5,6,7,8,9\n10,11,12,13,14')
        five_u4 = 'u4,u4,u4,u4,u4'

        def u4_types():
            # A fresh mapping per adapter, as in separate dict(...) calls.
            return dict(zip(range(5), ['u4'] * 5))

        adapter = textadapter.text_adapter(data, header=1, field_names=False)
        adapter.field_types = u4_types()
        assert_array_equal(
            adapter[:],
            np.array([(5, 6, 7, 8, 9), (10, 11, 12, 13, 14)], dtype=five_u4))

        data.seek(0)
        adapter = textadapter.text_adapter(data, header=2, field_names=False)
        adapter.field_types = u4_types()
        assert_array_equal(adapter[:], np.array([(10, 11, 12, 13, 14)], dtype=five_u4))

        data.seek(0)
        adapter = textadapter.text_adapter(data, header=1, field_names=True)
        adapter.field_types = u4_types()
        named = [(name, 'u4') for name in ('5', '6', '7', '8', '9')]
        assert_array_equal(adapter[:], np.array([(10, 11, 12, 13, 14)], dtype=named))
test_TextAdapter.py 文件源码 项目:TextAdapter 作者: ContinuumIO 项目源码 文件源码 阅读 16 收藏 0 点赞 0 评论 0
def test_delimiter(self):
        """Rows separated by comma, space, tab or 'x' all parse to (1, 2, 3) by default.

        With ``delimiter=None``, each line is read as one object-typed field.
        """
        for line in ('1,2,3\n', '1 2 3\n', '1\t2\t3\n', '1x2x3\n'):
            adapter = textadapter.text_adapter(StringIO(line), field_names=False)
            self.assert_equality(adapter[0].item(), (1, 2, 3))

        # Test no delimiter in single field csv data
        data = StringIO('aaa\nbbb\nccc')
        single_field = textadapter.text_adapter(data, field_names=False, delimiter=None)[:]
        expected = np.array([('aaa',), ('bbb',), ('ccc',)], dtype=[('f0', 'O')])
        assert_array_equal(single_field, expected)
test_TextAdapter.py 文件源码 项目:TextAdapter 作者: ContinuumIO 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def test_field_names(self):
        """Field-name handling: extra columns, duplicate names, explicit name list."""
        # Test for ignoring of extra fields
        data = StringIO('f0,f1\n0,1,2\n3,4,5')
        adapter = textadapter.text_adapter(data, 'csv', delimiter=',', field_names=True)
        parsed = adapter.to_array()
        self.assert_equality(parsed.dtype.names, ('f0', 'f1'))
        self.assert_equality(parsed[0].item(), (0, 1))
        self.assert_equality(parsed[1].item(), (3, 4))

        # Test for duplicate field names: the repeated name gets a numeric suffix
        data = StringIO('f0,field,field\n0,1,2\n3,4,5')
        adapter = textadapter.text_adapter(data, 'csv', delimiter=',', field_names=True, infer_types=False)
        adapter.set_field_types({0: 'u4', 1: 'u4', 2: 'u4'})
        parsed = adapter.to_array()
        self.assert_equality(parsed.dtype.names, ('f0', 'field', 'field1'))

        # Test for field names list
        data = StringIO('0,1,2\n3,4,5')
        adapter = textadapter.text_adapter(data, field_names=['a', 'b', 'c'], infer_types=False)
        adapter.field_types = {0: 'u4', 1: 'u4', 2: 'u4'}
        parsed = adapter[:]
        # assertEqual reports both tuples on failure; assertTrue(a == b) only
        # reports "False is not true".
        self.assertEqual(parsed.dtype.names, ('a', 'b', 'c'))
        assert_array_equal(parsed, np.array([(0, 1, 2), (3, 4, 5)], dtype=[('a', 'u4'), ('b', 'u4'), ('c', 'u4')]))
test_converter.py 文件源码 项目:torch2coreml 作者: prisma-ai 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_image_input(self):
        """Image input with image_scale=0.5 and zero biases: output equals input * 0.5.

        Presumably ``self.model`` passes its input straight through; confirm
        in the fixture.
        """
        from _torch_converter import convert

        preprocessing = {
            'is_bgr': False,
            'red_bias': 0.0,
            'green_bias': 0.0,
            'blue_bias': 0.0,
            'image_scale': 0.5
        }
        coreml_model = convert(
            self.model,
            [self.input.shape],
            input_names=['image'],
            image_input_names=['image'],
            preprocessing_args=preprocessing,
        )

        pixels = (np.random.rand(224, 224, 3) * 255).astype('uint8')
        image = Image.fromarray(pixels).convert('RGBA')
        prediction = coreml_model.predict({"image": image})["output"]
        # Reorder axes (C, H, W) -> (H, W, C) to line up with `pixels`.
        prediction = prediction.transpose((1, 2, 0))
        npt.assert_array_equal(prediction, pixels * 0.5)
test_breaks.py 文件源码 项目:mizani 作者: has2k1 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def test_log_breaks():
    """Check log_breaks on wide, degenerate, infinite and same-magnitude limits."""
    values = [2, 20, 2000]
    limits = min(values), max(values)
    npt.assert_array_equal(log_breaks()(limits), [1, 10, 100, 1000, 10000])
    npt.assert_array_equal(log_breaks(3)(limits), [1, 100, 10000])

    npt.assert_array_equal(log_breaks()((10000, 10000)), [10000])

    assert len(log_breaks()((float('-inf'), float('inf')))) == 0

    # When the limits are in the same order of magnitude, breaks still exist
    # and stay within one decade on either side.
    for lims, (low, high) in (([35, 60], (1, 100)), ([200, 800], (10, 1000))):
        result = log_breaks()(lims)
        assert len(result) > 0
        assert all(low < b < high for b in result)
test_breaks.py 文件源码 项目:mizani 作者: has2k1 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_minor_breaks():
    """Check minor_breaks for equidistant, non-equidistant and single major breaks."""
    # equidistant breaks
    equidistant = [1, 2, 3, 4]
    npt.assert_array_equal(minor_breaks()(equidistant, [0, 5]),
                           [.5, 1.5, 2.5, 3.5, 4.5])
    npt.assert_array_equal(minor_breaks(3)(equidistant, [2, 3]),
                           [2.25, 2.5, 2.75])

    # non-equidistant breaks
    limits = [0, 10]
    npt.assert_array_equal(minor_breaks()([1, 2, 4, 8], limits), [1.5, 3, 6])

    # a single major break yields no minor breaks
    assert len(minor_breaks()([2], limits)) == 0
test_breaks.py 文件源码 项目:mizani 作者: has2k1 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def test_date_breaks():
    """Check date_breaks for datetime limits, datetime64 limits and NaT."""
    # cpython datetime objects
    dates = [datetime(yr, 1, 1) for yr in (2010, 2026, 2015)]
    limits = min(dates), max(dates)
    expectations = (
        ('5 Years', [2010, 2015, 2020, 2025, 2030]),
        ('10 Years', [2010, 2020, 2030]),
    )
    for width, expected_years in expectations:
        years = [d.year for d in date_breaks(width)(limits)]
        npt.assert_array_equal(years, expected_years)

    # numpy datetime64 limits raise AttributeError
    stamps = [np.datetime64(i * 10, 'D') for i in range(1, 10)]
    decade_breaks = date_breaks('10 Years')
    limits = min(stamps), max(stamps)
    with pytest.raises(AttributeError):
        decade_breaks(limits)

    # NaT gives no breaks
    nat_limits = np.datetime64('NaT'), datetime(2017, 1, 1)
    assert len(date_breaks('10 Years')(nat_limits)) == 0
test_breaks.py 文件源码 项目:mizani 作者: has2k1 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def test_extended_breaks():
    """Check extended_breaks: count bound, reversed, infinite and zero-range limits."""
    x = np.arange(100)
    limits = x.min(), x.max()
    for n in (5, 7, 10, 13, 31):
        assert len(extended_breaks(n=n)(limits)) <= n + 1

    # Reverse limits give the same breaks
    breaks = extended_breaks(n=7)
    npt.assert_array_equal(breaks((0, 6)), breaks((6, 0)))

    # Infinite limits give no breaks
    breaks = extended_breaks(n=5)
    assert len(breaks((float('-inf'), float('inf')))) == 0

    # Zero range (discrete, then continuous): a single break at the value
    for degenerate in ([1, 1], [np.pi, np.pi]):
        result = breaks(degenerate)
        assert len(result) == 1
        assert result[0] == degenerate[1]
core.py 文件源码 项目:zipline-chinese 作者: zhanghan1990 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def check_arrays(x, y, err_msg='', verbose=True, check_dtypes=True):
    """
    Wrapper around np.testing.assert_array_equal that also checks the inputs
    have exactly the same Python type and, optionally, the same dtype.

    Parameters
    ----------
    x, y : array-like
        Values to compare. They must be of exactly the same type.
    err_msg : str, optional
        Message included in the failure report.
    verbose : bool, optional
        If True, the conflicting values are included in the failure report.
    check_dtypes : bool, optional
        If True (default), also require ``x.dtype == y.dtype``.

    Returns
    -------
    None
        Whatever ``assert_array_equal`` returns.

    Raises
    ------
    AssertionError
        If the types, the dtypes (when checked), or the values differ.

    See Also
    --------
    numpy.testing.assert_array_equal
    """
    # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
    if type(x) is not type(y):
        raise AssertionError("{x} != {y}".format(x=type(x), y=type(y)))
    # Previously ``check_dtypes`` was accepted but ignored.
    if check_dtypes and x.dtype != y.dtype:
        raise AssertionError("{x.dtype} != {y.dtype}".format(x=x, y=y))

    # Previously ``verbose=True`` was hard-coded, ignoring the caller's choice.
    return assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)


问题


面经


文章

微信
公众号

扫码关注公众号