python类setp()的实例源码

color.py 文件源码 项目:neurotools 作者: michaelerule 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def color_boxplot(bp,COLOR):
    '''
    Recolor every artist of a boxplot with one color.

    The Boxplot defaults are awful; this is a little better.

    bp    : the dict returned by ``boxplot`` (keys 'boxes', 'whiskers',
            'caps', 'fliers', 'medians').
    COLOR : color applied to boxes, whiskers, caps and fliers.

    Medians are drawn with ``GATHER[-1]`` (a module-level palette defined
    outside this function) instead of COLOR.
    '''
    outline_styles = (
        ('boxes',    dict(color=COLOR, edgecolor=COLOR)),
        ('whiskers', dict(color=COLOR, ls='-', lw=1)),
        ('caps',     dict(color=COLOR, lw=1)),
        ('fliers',   dict(color=COLOR, ms=4)),
    )
    for part, props in outline_styles:
        pylab.setp(bp[part], **props)
    # Medians use the palette's last color so they stand out from the outline
    pylab.setp(bp['medians'], color=GATHER[-1], lw=1.5, solid_capstyle='butt')


####################################################################### Three isoluminance hue wheels at varying brightness
# Unfortunately the hue distribution is a bit off for these and they come
# out a little heavy in the red, green, and blue. I don't recommend using
# them.
plot.py 文件源码 项目:spyking-circus 作者: spyking-circus 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def view_waveforms_clusters(data, halo, threshold, templates, amps_lim, n_curves=200, save=False):
    """Plot, for every template, a random subset of its clustered waveforms.

    Each panel shows up to ``n_curves`` waveforms of one cluster (grey), the
    template (red), the template scaled by its two amplitude limits (blue),
    dashed lines at +/- ``threshold`` and a dashed vertical line at the
    central sample.

    Parameters
    ----------
    data : 2-D array
        Waveforms, one per row (``data[k]`` is drawn as one curve).
    halo : 1-D int array
        Cluster label of each row of ``data``; only labels > -1 are used.
    threshold : float
        Detection threshold, drawn at +threshold and -threshold.
    templates : 2-D array
        One template per column (``templates[:, count]``).
    amps_lim : sequence of (low, high) pairs
        Amplitude limits per template.
    n_curves : int
        Maximum number of waveforms drawn per panel.
    save : False or (directory, suffix)
        If truthy, the figure is written to ``directory/waveforms_<suffix>``
        and closed; otherwise it is shown.
    """
    nb_templates = templates.shape[1]
    # subplot() needs integer grid dimensions; numpy.ceil returns a float.
    n_panels     = int(numpy.ceil(numpy.sqrt(nb_templates)))
    clust_idx    = numpy.unique(halo[halo > -1])
    pylab.figure()
    # Index of the middle sample of a waveform of length len(data[0]).
    center       = (len(data[0]) - 1)//2
    for count in range(nb_templates):
        pylab.subplot(n_panels, n_panels, count + 1)
        # Only the left column keeps y ticks and only the bottom row x ticks.
        if count % n_panels != 0:
            pylab.setp(pylab.gca(), yticks=[])
        if count < n_panels*(n_panels - 1):
            pylab.setp(pylab.gca(), xticks=[])

        subcurves = numpy.where(halo == clust_idx[count])[0]
        for k in numpy.random.permutation(subcurves)[:n_curves]:
            pylab.plot(data[k], '0.5')

        pylab.plot(templates[:, count], 'r')
        pylab.plot(amps_lim[count][0]*templates[:, count], 'b', alpha=0.5)
        pylab.plot(amps_lim[count][1]*templates[:, count], 'b', alpha=0.5)

        xmin, xmax = pylab.xlim()
        pylab.plot([xmin, xmax], [-threshold, -threshold], 'k--')
        pylab.plot([xmin, xmax], [threshold, threshold], 'k--')
        ymin, ymax = pylab.ylim()
        pylab.plot([center, center], [ymin, ymax], 'k--')
        pylab.title('Cluster %d' % count)

    if nb_templates > 0:
        pylab.tight_layout()
    if save:
        pylab.savefig(os.path.join(save[0], 'waveforms_%s' % save[1]))
        pylab.close()
    else:
        pylab.show()
plot.py 文件源码 项目:spyking-circus 作者: spyking-circus 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def view_raw_templates(file_name, n_temp=2, square=True):
    """Show a random (or explicitly chosen) selection of templates.

    Each template gets its own subplot. Inside it, every electrode's trace is
    colored along the 'winter' colormap by electrode index.

    Parameters
    ----------
    file_name : str
        Not used in this body.
    n_temp : int or iterable
        If an int: how many templates to pick at random among the first
        ``N_tm//2`` (``n_temp**2`` of them when ``square`` is true).
        If iterable: the exact template indices to draw.
    square : bool
        Arrange panels on an ``n_temp`` x ``n_temp`` grid (y ticks only on the
        left column, x ticks only on the bottom row); otherwise stack them in
        one column with x ticks only on the last panel.

    NOTE(review): ``templates`` is neither a parameter nor loaded here, so it
    must come from module scope. It is unpacked as a 3-D ``shape`` and
    indexed as ``templates[electrode, :, template]`` — confirm where it is
    defined.
    """
    N_e, N_t, N_tm = templates.shape

    if numpy.iterable(n_temp):
        idx = n_temp
    else:
        n_pick = n_temp**2 if square else n_temp
        idx = numpy.random.permutation(numpy.arange(N_tm//2))[:n_pick]

    import matplotlib.colors as colors
    winter = pylab.get_cmap('winter')
    electrode_norm = colors.Normalize(vmin=0, vmax=N_e)
    scalar_map = pylab.cm.ScalarMappable(norm=electrode_norm, cmap=winter)

    pylab.figure()
    for panel, template_id in enumerate(idx):
        if square:
            pylab.subplot(n_temp, n_temp, panel + 1)
            on_left_column = numpy.mod(panel, n_temp) == 0
            on_bottom_row = panel >= n_temp*(n_temp - 1)
            if not on_left_column:
                pylab.setp(pylab.gca(), yticks=[])
            if not on_bottom_row:
                pylab.setp(pylab.gca(), xticks=[])
        else:
            pylab.subplot(len(idx), 1, panel + 1)
            if panel != (len(idx) - 1):
                pylab.setp(pylab.gca(), xticks=[])

        for electrode in xrange(N_e):
            trace_color = scalar_map.to_rgba(electrode)
            pylab.plot(templates[electrode, :, template_id], color=trace_color)

        pylab.title('Template %d' % template_id)
    pylab.tight_layout()
    pylab.show()
plot.py 文件源码 项目:privcount 作者: privcount 项目源码 文件源码 阅读 70 收藏 0 点赞 0 评论 0
def plot_bar_chart(page, datasets, dataset_labels, dataset_colors, x_group_labels, err=0, title=None, xlabel='Bins', ylabel='Counts'):
    """Draw a grouped bar chart, save it through ``page`` and close the figure.

    Parameters
    ----------
    page : object with ``savefig()``
        Destination of the figure (presumably a matplotlib PdfPages — confirm).
    datasets : sequence of equal-length sequences
        One series per entry; element ``j`` of every series belongs to group ``j``.
    dataset_labels, dataset_colors : sequences
        Legend label and bar color of each series.
    x_group_labels : sequence
        Tick label of each group (rotated 15 degrees unless there is one group).
    err : scalar or array-like
        Passed as ``yerr`` to every bar call.
    title, xlabel, ylabel : str or None
        Skipped when None.

    Raises AssertionError when the lengths of the inputs disagree.
    """
    assert len(datasets) == len(dataset_colors) == len(dataset_labels)
    for series in datasets:
        assert len(series) == len(datasets[0])
        assert len(series) == len(x_group_labels)

    num_x_groups = len(datasets[0])
    group_starts = pylab.arange(num_x_groups)
    bar_width = 1.0 / float(len(datasets)+1)

    figure = pylab.figure()
    axis = figure.add_subplot(111)
    bars = [
        axis.bar(group_starts + (bar_width*offset), series, bar_width, yerr=err,
                 color=series_color,
                 error_kw=dict(ecolor='pink', lw=3, capsize=6, capthick=3))
        for offset, (series, series_color) in enumerate(zip(datasets, dataset_colors))
    ]

    if title is not None:
        axis.set_title(title)
    if ylabel is not None:
        axis.set_ylabel(ylabel)
    if xlabel is not None:
        axis.set_xlabel(xlabel)

    axis.set_xticks(group_starts + bar_width*len(datasets)/2)
    x_tick_texts = axis.set_xticklabels(x_group_labels)
    pylab.setp(x_tick_texts, rotation=(0 if num_x_groups == 1 else 15), fontsize=10)
    axis.set_xlim(-bar_width, num_x_groups)
    pylab.setp(axis.get_yticklabels(), rotation=0, fontsize=10)

    # One legend handle per series: the first rectangle of its bar container
    axis.legend([container[0] for container in bars], dataset_labels)
    page.savefig()
    pylab.close()
two_sigma_financial_modelling.py 文件源码 项目:PortfolioTimeSeriesAnalysis 作者: MizioAnd 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def dendrogram(df, number_of_clusters=None):
    """Draw a seaborn clustermap of the column-wise correlations of ``df``.

    Row and column side bars are colored from a cubehelix palette with
    ``number_of_clusters`` colors, assigned to the columns of ``df`` in order.

    Parameters
    ----------
    df : pandas.DataFrame
        Data whose ``astype(float).corr()`` matrix is plotted.
    number_of_clusters : int, optional
        Palette size. Defaults to ``int(df.shape[1] / 1.2)`` for the ``df``
        given to this call.
    """
    # Resolved here rather than in the signature: a default expression is
    # evaluated once when the ``def`` runs, where no ``df`` argument exists.
    if number_of_clusters is None:
        number_of_clusters = int(df.shape[1] / 1.2)
    used_networks = np.arange(0, number_of_clusters, dtype=int)

    # Create a custom palette to identify the networks
    network_pal = sns.cubehelix_palette(len(used_networks),
                                        light=.9, dark=.1, reverse=True,
                                        start=1, rot=-2)
    # NOTE(review): zip() stops at the shorter input, so only the first
    # len(network_pal) columns get a color and the rest map to NaN. Also the
    # keys are str(column) while the values mapped below are the raw column
    # labels, so non-string labels will not match — confirm this is intended.
    network_lut = dict(zip(map(str, df.columns), network_pal))

    # Convert the palette to vectors that will be drawn on the side of the matrix
    networks = df.columns.get_level_values(None)
    network_colors = pd.Series(networks, index=df.columns).map(network_lut)
    sns.set(font="monospace")
    # Create custom colormap
    cmap = sns.diverging_palette(h_neg=210, h_pos=350, s=90, l=30, as_cmap=True)
    cg = sns.clustermap(df.astype(float).corr(), cmap=cmap, linewidths=.5, row_colors=network_colors,
                        col_colors=network_colors)
    plt.setp(cg.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
    plt.setp(cg.ax_heatmap.xaxis.get_majorticklabels(), rotation=90)
    plt.show()
dvhcalc.py 文件源码 项目:DVH 作者: glucee 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def main():
    """Plot the calculated DVH of every structure in the example data set.

    Reads the example RT structure set ``testdata/rtss.dcm`` and RT dose
    ``testdata/rtdose.dcm`` (downloaded from the dicompyler website as
    testdata.zip). It computes a DVH per structure with ``dvhcalc.get_dvh``
    and draws each non-empty one as a dashed curve in the structure's color,
    normalised to 100% at the first bin. The result is saved to
    ``testdata/dvh.png``.
    """
    rtssfile = 'testdata/rtss.dcm'
    rtdosefile = 'testdata/rtdose.dcm'
    RTss = dicomparser.DicomParser(rtssfile)
    RTstructures = RTss.GetStructures()

    # Generate the calculated DVHs
    calcdvhs = {}
    # .items() works on Python 2 and 3; .iteritems() does not exist on Python 3.
    for key, structure in RTstructures.items():
        calcdvhs[key] = dvhcalc.get_dvh(rtssfile, rtdosefile, key)
        counts = calcdvhs[key].counts
        # Skip empty DVHs and those whose first bin (the normalisation divisor) is 0.
        if len(counts) and counts[0] != 0:
            print('DVH found for ' + structure['name'])
            pl.plot(counts * 100/counts[0],
                    color=dvhcalc.np.array(structure['color'], dtype=float) / 255,
                    label=structure['name'],
                    linestyle='dashed')
    # NOTE(review): the x axis is the DVH bin index; "Distance (cm)" looks like
    # the wrong label for a dose-volume histogram — confirm before changing.
    pl.xlabel('Distance (cm)')
    pl.ylabel('Percentage Volume')
    pl.legend(loc=7, borderaxespad=-5)
    pl.setp(pl.gca().get_legend().get_texts(), fontsize='x-small')
    pl.savefig('testdata/dvh.png', dpi = 75)
house_prices.py 文件源码 项目:HousePrices 作者: MizioAnd 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def dendrogram(df, number_of_clusters, agglomerated_feature_labels):
    """Draw a seaborn clustermap of the column-wise correlations of ``df``.

    Row and column side bars are colored from a cubehelix palette with one
    color per distinct value in ``agglomerated_feature_labels``. The colors
    are assigned to the columns of ``df`` in order.

    Parameters
    ----------
    df : pandas.DataFrame
        Data whose ``astype(float).corr()`` matrix is plotted.
    number_of_clusters : int
        Not used in this body.
    agglomerated_feature_labels : array-like
        Cluster label per feature; only the number of distinct labels is used.
    """
    import seaborn as sns

    palette_size = len(np.unique(agglomerated_feature_labels))
    palette = sns.cubehelix_palette(palette_size,
                                    light=.9, dark=.1, reverse=True,
                                    start=1, rot=-2)
    # Color keyed by the string form of each column label; zip stops at the
    # shorter input, so columns beyond the palette size get no color.
    color_by_name = dict(zip([str(col) for col in df.columns], palette))

    column_labels = df.columns.get_level_values(None)
    side_colors = pd.Series(column_labels, index=df.columns).map(color_by_name)

    sns.set(font="monospace")
    diverging_cmap = sns.diverging_palette(h_neg=210, h_pos=350, s=90, l=30, as_cmap=True)
    grid = sns.clustermap(df.astype(float).corr(), cmap=diverging_cmap, linewidths=.5,
                          row_colors=side_colors, col_colors=side_colors)
    plt.setp(grid.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
    plt.setp(grid.ax_heatmap.xaxis.get_majorticklabels(), rotation=90)
    plt.show()
viewRecon.py 文件源码 项目:emc_and_dm 作者: eucall-software 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def make_panel_of_intensity_slices(fn, c_n=9):
    """Save a square grid of log10 intensity slices from a reconstruction history.

    Reads ``/history/intensities`` (``c_n`` entries requested) and
    ``/history/quaternion`` (all entries) from ``fn`` via
    ``read.extract_arr_from_h5``. It draws ``sqrt_len**2`` panels, the largest
    square that fits the number of intensity volumes read. Each panel shows
    the slice at index ``shape[0]//2`` along the first axis as log10 of the
    intensities, with non-positive values replaced by 1e-8. Panels share a
    color scale of [-6.5, -3.5] and are titled with their iteration and
    quaternion labels. The figure is written to ``recon_series.pdf``.

    Parameters
    ----------
    fn : str
        Path of the HDF5 reconstruction file.
    c_n : int
        Number of intensity snapshots requested from the history.
    """
    M.rcParams.update({'font.size': 13})
    intensList = read.extract_arr_from_h5(fn, "/history/intensities", n=c_n)
    quatList = read.extract_arr_from_h5(fn, "/history/quaternion", n=-1)
    P.ioff()
    intens_len  = len(intensList)
    sqrt_len    = int(N.sqrt(intens_len))
    intens_sh   = intensList[0].shape
    # Integer division: a float index raises TypeError on Python 3
    # (and is identical to "/" for ints on Python 2).
    mid_idx     = intens_sh[0]//2
    iter_labels = read.create_interval_labels(len(quatList), c_n)[:intens_len]
    to_plot     = intensList[:intens_len]
    quat_label  = quatList[N.array(iter_labels)-1][:intens_len]
    plot_titles = ["iter_%d, quat_%d"%(ii,jj) for ii,jj in zip(iter_labels, quat_label)]
    # squeeze=False keeps `ax` 2-D even for a 1x1 grid (fewer than 4 snapshots).
    fig, ax     = P.subplots(sqrt_len, sqrt_len, sharex=True, sharey=True, squeeze=False,
                             figsize=(1.8*sqrt_len, 2.*sqrt_len))
    plt_counter = 0
    for r in range(sqrt_len):
        for c in range(sqrt_len):
            curr_slice = to_plot[plt_counter][mid_idx]
            # Replace non-positive values by 1e-8 so log10 is defined everywhere.
            curr_slice = curr_slice*(curr_slice>0.) + 1.E-8*(curr_slice<=0.)
            ax[r,c].set_title(plot_titles[plt_counter], fontsize=11.5)
            im = ax[r,c].imshow(N.log10(curr_slice), vmin=-6.5, vmax=-3.5, aspect='auto', cmap=P.cm.coolwarm)
            plt_counter += 1
    fig.subplots_adjust(wspace=0.01)
    # Tick positions/labels centred on the middle of the (last) slice.
    (shx, shy) = curr_slice.shape
    (h_shx, h_shy) = (shx/2, shy/2)
    xt = N.linspace(0.5*h_shx, shx-.5*h_shx-1, 3).astype('int')
    xt_l = N.linspace(-0.5*h_shx, 0.5*h_shx, 3).astype('int')
    yt = N.linspace(0, shy-1, 3).astype('int')
    yt_l = N.linspace(-1*h_shy, h_shy, 3).astype('int')
    P.setp(ax, xticks=xt, xticklabels=xt_l, yticks=yt, yticklabels=yt_l)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.8])
    fig.colorbar(im, cax=cbar_ax, label="log10(intensities)")
    img_name = "recon_series.pdf"
    P.savefig(img_name, bbox_inches='tight')
    print("Image has been saved as %s" % img_name)
    P.close(fig)
analyzer.py 文件源码 项目:spyking-circus-ort 作者: spyking-circus 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def view_synthetic_templates(self, indices=None, time=None, nn=100, hf_dist=45, a_dist=1.0):
    """Overlay the synthetic templates of the selected cells on the probe layout.

    Parameters
    ----------
    indices : None, int or iterable of int
        Cells to draw; None means ``range(self.nb_cells)`` and a single
        non-iterable value is wrapped in a list.
    time, nn, hf_dist, a_dist
        Forwarded unchanged to ``self._get_synthetic_template``.

    Each channel's waveform is shifted according to the channel position
    inside ``self.probe.field_of_view``. The vertical spread uses 10x the peak
    absolute amplitude of the first template drawn. Each cell is colored by
    ``self._scalarMap_synthetic``.
    """
    def offset(value, lo, hi, span):
        # Position of `value` relative to `lo`, divided by (hi - lo + 1), scaled to `span`.
        return ((value - lo)/(float(hi - lo) + 1))*span

    if indices is None:
        indices = range(self.nb_cells)
    if not numpy.iterable(indices):
        indices = [indices]

    scaling = None
    pylab.figure()

    for cell in indices:
        template = self._get_synthetic_template(cell, time, nn, hf_dist, a_dist).toarray()
        width = template.shape[1]
        xmin, xmax = self.probe.field_of_view['x_min'], self.probe.field_of_view['x_max']
        ymin, ymax = self.probe.field_of_view['y_min'], self.probe.field_of_view['y_max']
        if scaling is None:
            scaling = 10*numpy.max(numpy.abs(template))
        cell_color = self._scalarMap_synthetic.to_rgba(cell)
        samples = numpy.arange(width)

        for channel in xrange(self.nb_channels):
            x, y = self.probe.positions[:, channel]
            pylab.plot(offset(x, xmin, xmax, 2*width) + samples,
                       offset(y, ymin, ymax, scaling) + template[channel, :],
                       color=cell_color)

    pylab.tight_layout()
    pylab.setp(pylab.gca(), xticks=[], yticks=[])
    pylab.xlim(xmin, 3*width)
    pylab.show()
analyzer.py 文件源码 项目:spyking-circus-ort 作者: spyking-circus 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def view_circus_templates(self, indices=None):
    """Overlay the stored (fitted) templates on the probe layout.

    Parameters
    ----------
    indices : None, int or iterable of int
        Templates to draw; None means ``range(self.nb_templates)`` and a
        single non-iterable value is wrapped in a list.

    Templates and norms are read from ``self.template_store``. Each template
    is reshaped to (nb_channels, width) and multiplied by its norm. Every
    channel's waveform is shifted according to the channel position inside
    ``self.probe.field_of_view``. The vertical spread uses 10x the peak
    absolute amplitude of the first template drawn.
    """
    def offset(value, lo, hi, span):
        # Position of `value` relative to `lo`, divided by (hi - lo + 1), scaled to `span`.
        return ((value - lo)/(float(hi - lo) + 1))*span

    if indices is None:
        indices = range(self.nb_templates)
    if not numpy.iterable(indices):
        indices = [indices]

    stored = self.template_store.get(indices, ['templates', 'norms'])
    width = self.template_store.width
    templates = stored.pop('templates').T
    norms = stored.pop('norms')
    scaling = None
    pylab.figure()

    for position, template_id in enumerate(indices):
        template = templates[position].toarray().reshape(self.nb_channels, width) * norms[position]
        xmin, xmax = self.probe.field_of_view['x_min'], self.probe.field_of_view['x_max']
        ymin, ymax = self.probe.field_of_view['y_min'], self.probe.field_of_view['y_max']
        if scaling is None:
            scaling = 10*numpy.max(numpy.abs(template))
        template_color = self._scalarMap_circus.to_rgba(template_id)
        samples = numpy.arange(width)

        for channel in xrange(self.nb_channels):
            x, y = self.probe.positions[:, channel]
            pylab.plot(offset(x, xmin, xmax, 2*width) + samples,
                       offset(y, ymin, ymax, scaling) + template[channel, :],
                       color=template_color)

    pylab.tight_layout()
    pylab.setp(pylab.gca(), xticks=[], yticks=[])
    pylab.xlim(xmin, 3*width)
    pylab.show()
graph.py 文件源码 项目:twitter-bot-detection 作者: franckbrignoli 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def boxplot(self, values_human, values_bot, title, path):
    """Save a two-box plot comparing human and bot values to ``path``.

    ``title`` is used as the y-axis label; the boxes are labelled
    "Humans" and "Bots". The figure is closed after saving so that repeated
    calls do not accumulate open figures.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.yaxis.grid(True)
    ax.set_ylabel(title)

    ax.boxplot([values_human, values_bot], vert=True, patch_artist=True)

    plt.setp(ax, xticks=[1, 2], xticklabels=["Humans", "Bots"])
    # Save this specific figure (not whichever is current) and release it.
    fig.savefig(path)
    plt.close(fig)
plot.py 文件源码 项目:spyking-circus 作者: spyking-circus 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def view_templates(file_name, temp_id=0, best_elec=None, templates=None):
    """Plot one template with each electrode's trace placed at its probe position.

    The best electrode is drawn in red, its neighbours (from the probe edges)
    in black and every other electrode in grey.

    Parameters
    ----------
    file_name : str
        Parameter file passed to ``CircusParser``.
    temp_id : int
        Column of the template to draw.
    best_elec : None, 'auto' or int
        None -> ``clusters['electrodes'][temp_id]``; 'auto' -> electrode holding
        the most negative sample of the template; otherwise used as given.
    templates : optional
        Template matrix; loaded with ``load_data(params, 'templates')`` when
        None. Column ``temp_id`` is densified with ``toarray()`` and reshaped to
        (N_e, N_t), so a sparse (N_e*N_t, n_templates) matrix is expected.

    Returns
    -------
    The electrode used as best electrode.
    """
    params           = CircusParser(file_name)
    N_e              = params.getint('data', 'N_e')
    N_t              = params.getint('detection', 'N_t')
    N_total          = params.getint('data', 'N_total')
    nodes, edges     = get_nodes_and_edges(params)
    inv_nodes        = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)

    if templates is None:
        templates    = load_data(params, 'templates')
    clusters         = load_data(params, 'clusters')
    probe            = params.probe

    positions = {}
    for group in probe['channel_groups'].values():
        positions.update(group['geometry'])

    # Densify the requested template once; scaling, 'auto' and plotting all
    # index it as (electrode, time).
    template = templates[:, temp_id].toarray().reshape(N_e, N_t)
    scaling  = 10*numpy.max(numpy.abs(template))

    # Bounding box of the electrode positions (always includes the origin).
    xmin = xmax = ymin = ymax = 0
    for i in range(N_e):
        x, y = positions[i][0], positions[i][1]
        xmin, xmax = min(xmin, x), max(xmax, x)
        ymin, ymax = min(ymin, y), max(ymax, y)

    if best_elec is None:
        best_elec = clusters['electrodes'][temp_id]
    elif best_elec == 'auto':
        best_elec = numpy.argmin(numpy.min(template, 1))
    neighbours = inv_nodes[edges[nodes[best_elec]]]

    pylab.figure()
    for i in range(N_e):
        x, y     = positions[i]
        xpadding = ((x - xmin)/(float(xmax - xmin) + 1))*(2*N_t)
        ypadding = ((y - ymin)/(float(ymax - ymin) + 1))*scaling

        if i == best_elec:
            c = 'r'
        elif i in neighbours:
            c = 'k'
        else:
            c = '0.5'
        pylab.plot(xpadding + numpy.arange(0, N_t), ypadding + template[i], color=c)
    pylab.tight_layout()
    pylab.setp(pylab.gca(), xticks=[], yticks=[])
    pylab.xlim(xmin, 3*N_t)
    pylab.show()
    return best_elec
plot.py 文件源码 项目:spyking-circus 作者: spyking-circus 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def view_masks(file_name, t_start=0, t_stop=1, n_elec=0):
    """Plot electrode ``n_elec`` and its neighbours with their detected peaks.

    A chunk of ``(t_stop - t_start)*rate`` samples is read (``t_start`` and
    ``t_stop`` presumably in seconds — confirm). The chunk is optionally
    spatially and temporally whitened. Peaks are then detected on every
    electrode with ``algo.detect_peaks`` using the stored thresholds.

    The traces of ``n_elec``'s neighbours are stacked 5 units apart, with
    their peaks in red. ``safety_time`` samples on either side of every peak
    of ``n_elec`` are shaded.

    Returns
    -------
    dict
        Electrode index -> peaks returned by ``algo.detect_peaks``.
    """
    params          = CircusParser(file_name)
    data_file       = params.get_data_file()
    data_file.open()
    # Everything up to the read runs while the file is open; always close it.
    try:
        N_e             = params.getint('data', 'N_e')
        N_total         = params.nb_channels
        sampling_rate   = params.rate
        do_temporal_whitening = params.getboolean('whitening', 'temporal')
        do_spatial_whitening  = params.getboolean('whitening', 'spatial')
        nodes, edges     = get_nodes_and_edges(params)
        chunk_size       = (t_stop - t_start)*sampling_rate
        padding          = (t_start*sampling_rate, t_start*sampling_rate)
        inv_nodes        = numpy.zeros(N_total, dtype=numpy.int32)
        inv_nodes[nodes] = numpy.argsort(nodes)
        safety_time      = params.getint('clustering', 'safety_time')

        if do_spatial_whitening:
            spatial_whitening  = load_data(params, 'spatial_whitening')
        if do_temporal_whitening:
            temporal_whitening = load_data(params, 'temporal_whitening')

        thresholds = load_data(params, 'thresholds')
        data = data_file.get_data(0, chunk_size, padding=padding, nodes=nodes)
    finally:
        data_file.close()

    # Electrodes neighbouring n_elec, as local indices
    indices = inv_nodes[edges[nodes[n_elec]]]

    if do_spatial_whitening:
        data = numpy.dot(data, spatial_whitening)
    if do_temporal_whitening:
        data = scipy.ndimage.convolve1d(data, temporal_whitening, axis=0, mode='constant')

    peaks = {}
    for i in range(N_e):
        peaks[i] = algo.detect_peaks(data[:, i], thresholds[i], valley=True, mpd=0)

    pylab.figure()

    for count, i in enumerate(indices):
        pylab.plot(count*5 + data[:, i], '0.25')
        pylab.scatter(peaks[i], count*5 + data[peaks[i], i], s=10, c='r')

    for peak in peaks[n_elec]:
        pylab.axvspan(peak - safety_time, peak + safety_time, facecolor='r', alpha=0.5)

    pylab.ylim(-5, len(indices)*5)
    # NOTE(review): the x axis is the sample index within the chunk; the
    # "ms" unit in this label may not hold — confirm.
    pylab.xlabel('Time [ms]')
    pylab.ylabel('Electrode')
    pylab.tight_layout()
    pylab.setp(pylab.gca(), yticks=[])
    pylab.show()
    return peaks
bayes.py 文件源码 项目:nmmn 作者: rsnemmen 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def allplot(xb,yb,bins=30,fig=1,xlabel='x',ylabel='y'):
    """
Scatter plot of two sampled variables with Gaussian-KDE contours and
marginal histograms on the top and on the right.

Input:
xb, yb : sample arrays to analyze (e.g. traces produced by PyMC; they must
    support .min() and .max()). Example: xb=M.theta, yb=M.slope.
bins : number of bins of each marginal histogram.
fig : figure number passed to pylab.figure; the figure is cleared first.
xlabel, ylabel : labels of the scatter panel.

Inherited from Tommy LE BLANC's code at astroplotlib|STSCI.
    """
    X, Y = xb, yb

    pylab.figure(fig)
    pylab.clf()

    layout = pylab.GridSpec(2, 2, width_ratios=[3, 1], height_ratios=[1, 3],
                            wspace=0.07, hspace=0.07)
    scat = pylab.subplot(layout[2])                # bottom-left: joint panel
    histx = pylab.subplot(layout[0], sharex=scat)  # top: X marginal
    histy = pylab.subplot(layout[3], sharey=scat)  # bottom-right: Y marginal

    # Samples drawn underneath the contours
    scat.plot(X, Y, linestyle='none', marker='o', color='green', mec='green', alpha=.2, zorder=-99)

    # KDE evaluated on a grid with step (max - min)/25 along each axis
    density = scipy.stats.gaussian_kde([X, Y])
    x_lo, x_hi = X.min(), X.max()
    y_lo, y_hi = Y.min(), Y.max()
    x, y = numpy.mgrid[x_lo:x_hi:(x_hi - x_lo)/25., y_lo:y_hi:(y_hi - y_lo)/25.]
    z = numpy.array(density.evaluate([x.flatten(), y.flatten()])).reshape(x.shape)
    scat.contour(x, y, z, linewidths=2)
    scat.set_xlabel(xlabel)
    scat.set_ylabel(ylabel)

    # Marginals; their tick labels facing the joint panel are hidden
    histx.hist(X, bins, histtype='stepfilled')
    pylab.setp(histx.get_xticklabels(), visible=False)
    histy.hist(Y, bins, histtype='stepfilled', orientation='horizontal')
    pylab.setp(histy.get_yticklabels(), visible=False)

    pylab.draw()
    pylab.show()
bayes.py 文件源码 项目:nmmn 作者: rsnemmen 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def jointplotx(X,Y,xlabel=None,ylabel=None,binsim=40,binsh=20,binscon=15):
    r"""
Plots the joint distribution of posteriors for X and Y, together with the 1D
histograms showing the median and the 15.87/84.13 percentiles (+/-1 sigma).
Uses a simple method for drawing the confidence contours compared to
jointplot (which is wrong).

The joint panel is a grey 2D histogram image (binsim x binsim bins).
Contours are overplotted at 5% (blue) and 68% (black) of the peak of a
coarser binscon x binscon histogram.

The work that went in creating this method is shown, step by step, in
the ipython notebook "error contours.ipynb". Sources of inspiration:
- http://python4mpia.github.io/intro/quick-tour.html

Usage:
>>> jointplotx(M.rtr.trace(),M.mdot.trace(),xlabel='$\log \ r_{\rm tr}$', ylabel='$\log \ \dot{m}$')
    """
    # Generates 2D histogram for image (raw counts).
    histt, xt, yt = numpy.histogram2d(X, Y, bins=[binsim, binsim])
    histt = numpy.transpose(histt)  # Beware: numpy switches axes, so switch back.

    # assigns correct proportions to subplots
    pylab.figure()
    gs = pylab.GridSpec(2, 2, width_ratios=[3,1], height_ratios=[1,3], wspace=0.001, hspace=0.001)
    con=pylab.subplot(gs[2])
    histx=pylab.subplot(gs[0], sharex=con)
    histy=pylab.subplot(gs[3], sharey=con)

    # Image
    con.imshow(histt,extent=[xt[0],xt[-1], yt[0],yt[-1]],origin='lower',cmap=pylab.cm.gray_r,aspect='auto')

    # Overplot with error contours (raw counts of a coarser histogram)
    histdata, x, y = numpy.histogram2d(X, Y, bins=[binscon, binscon])
    histdata = numpy.transpose(histdata)  # Beware: numpy switches axes, so switch back.
    pmax  = histdata.max()
    # matplotlib requires increasing contour levels; colors follow the same
    # order (0.05*pmax blue, 0.68*pmax black).
    con.contour(histdata, levels=[0.05*pmax, 0.68*pmax], extent=[x[0],x[-1], y[0],y[-1]], colors=['blue','black'])
    if xlabel is not None: con.set_xlabel(xlabel)
    if ylabel is not None: con.set_ylabel(ylabel)

    # X-axis histogram
    histx.hist(X, binsh, histtype='stepfilled',facecolor='lightblue')
    pylab.setp(histx.get_xticklabels(), visible=False)  # no X label
    pylab.setp(histx.get_yticklabels(), visible=False)  # no Y label
    # set_ylim() with no arguments returns the current limits and also turns
    # y autoscaling off, so the lines below span the axis without rescaling it.
    yax=histx.set_ylim()
    histx.plot([numpy.median(X),numpy.median(X)],[yax[0],yax[1]],'k-',linewidth=2) # median
    xsd=scipy.stats.scoreatpercentile(X, [15.87,84.13])
    histx.plot([xsd[0],xsd[0]],[yax[0],yax[1]],'k--') # -1sd
    histx.plot([xsd[-1],xsd[-1]],[yax[0],yax[1]],'k--') # +1sd

    # Y-axis histogram
    histy.hist(Y, binsh, histtype='stepfilled', orientation='horizontal',facecolor='lightyellow')
    pylab.setp(histy.get_yticklabels(), visible=False)  # no Y label
    pylab.setp(histy.get_xticklabels(), visible=False)  # no X label
    # Horizontal lines with median and 1sigma confidence (x autoscaling frozen as above)
    xax=histy.set_xlim()
    histy.plot([xax[0],xax[1]],[numpy.median(Y),numpy.median(Y)],'k-',linewidth=2) # median
    ysd=scipy.stats.scoreatpercentile(Y, [15.87,84.13])
    histy.plot([xax[0],xax[1]],[ysd[0],ysd[0]],'k--') # -1sd
    histy.plot([xax[0],xax[1]],[ysd[-1],ysd[-1]],'k--') # +1sd


问题


面经


文章

微信
公众号

扫码关注公众号