Python 函数 color_palette() 的实例源码

barplot.py 文件源码 项目:coquery 作者: gkunter 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def set_defaults(self):
    """
    Set the plot defaults for the bar plot.

    The colour palette is chosen from the number of levels of the grouping
    factor that is mapped to colour (the second factor when there are two
    grouping factors, otherwise the first):

    * 2, 4, 6, 8 or 12 levels -> "Paired"
    * otherwise, with two grouping factors -> "Set3" (a qualitative palette)
    * otherwise -> "RdPu" (a sequential red-to-purple palette)

    After the base-class defaults have been applied, the axis and legend
    labels are derived from the grouping factors and from whether
    percentages or raw frequencies are shown.
    """
    two_factors = len(self._groupby) == 2
    colour_levels = self._levels[1] if two_factors else self._levels[0]

    # NOTE(review): 10 is even and below 13 but is not listed here, so ten
    # levels fall through to "Set3"/"RdPu" -- confirm this is intentional.
    if len(colour_levels) in (2, 4, 6, 8, 12):
        palette = "Paired"
    elif two_factors:
        palette = "Set3"
    else:
        palette = "RdPu"
    self.options["color_palette"] = palette
    super(Visualizer, self).set_defaults()

    self.options["label_x_axis"] = (
        "Percentage" if self.percentage else "Frequency")

    translate = options.cfg.main_window.Session.translate_header

    if two_factors:
        self.options["label_y_axis"] = translate(self._groupby[0])
        self.options["label_legend"] = translate(self._groupby[1])
    else:
        self.options["label_legend"] = translate(self._groupby[0])
        # With percentages and a single factor, the y axis stays unlabelled.
        self.options["label_y_axis"] = (
            "" if self.percentage else translate(self._groupby[0]))
barcodeplot.py 文件源码 项目:coquery 作者: gkunter 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def set_defaults(self):
    """
    Set the plot defaults for the barcode plot.

    Uses the "Paired" palette with one colour per level of the first
    grouping factor (a single colour when there are no levels), then
    applies the base-class defaults.  The x axis is labelled
    "Corpus position"; the y axis carries the first grouping factor's name
    only when that factor has at least two levels.
    """
    self.options["color_palette"] = "Paired"
    # The y-axis check below already allows self._levels to be empty, so
    # guard here as well instead of raising IndexError on self._levels[0].
    self.options["color_number"] = (
        len(self._levels[0]) if self._levels else 1)
    super(Visualizer, self).set_defaults()
    self.options["label_x_axis"] = "Corpus position"
    if not self._levels or len(self._levels[0]) < 2:
        self.options["label_y_axis"] = ""
    else:
        self.options["label_y_axis"] = self._groupby[0]
visualizer.py 文件源码 项目:coquery 作者: gkunter 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def set_defaults(self):
    """
    Fill in colour and legend defaults that have not been set explicitly.

    Only options that are missing or falsy in ``self.options`` are set:

    * ``color_number``: number of levels of the last grouping factor
      (1 when there are no levels);
    * ``label_legend_columns``: 1;
    * ``color_palette``: "Paired" when there are no levels, when the last
      factor has 2, 4 or 6 levels, or when there are two grouping factors;
      "RdPu" otherwise.

    The figure font is always taken from the running Qt application, and
    the palette values are generated if they are still missing.
    """
    # NOTE: a separate branch for numerical axes used to live here, but it
    # was disabled with ``if self.numerical_axes and False`` and never ran,
    # so it has been dropped.
    has_levels = len(self._levels) > 0

    if not self.options.get("color_number"):
        # self._levels[-1] would raise IndexError without levels; use one
        # colour, matching the no-levels palette branch below.
        self.options["color_number"] = (
            len(self._levels[-1]) if has_levels else 1)
    if not self.options.get("label_legend_columns"):
        self.options["label_legend_columns"] = 1
    if not self.options.get("color_palette"):
        if not has_levels:
            self.options["color_palette"] = "Paired"
            self.options["color_number"] = 1
        elif len(self._levels[-1]) in (2, 4, 6):
            self.options["color_palette"] = "Paired"
        elif len(self._groupby) == 2:
            self.options["color_palette"] = "Paired"
        else:
            self.options["color_palette"] = "RdPu"

    self.options["figure_font"] = (
        QtWidgets.QApplication.instance().font())

    if not self.options.get("color_palette_values"):
        self.set_palette_values(self.options["color_number"])
visualizer.py 文件源码 项目:coquery 作者: gkunter 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def set_palette_values(self, n=None):
    """
    Generate the palette colours for the current palette name.

    A truthy *n* becomes the new ``color_number`` option. Otherwise the
    stored ``color_number`` is used. For the "custom" palette the stored
    colour values are left untouched.
    """
    if n:
        self.options["color_number"] = n
    else:
        n = self.options["color_number"]

    palette_name = self.options["color_palette"]
    if palette_name == "custom":
        return
    self.options["color_palette_values"] = sns.color_palette(palette_name, n)
visualizationdesigner.py 文件源码 项目:coquery 作者: gkunter 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def show_palette(self):
    """
    Show a 12-colour preview of the selected palette in the test area.

    The colour test area is cleared, then one list item is added per colour
    of ``sns.color_palette(self._palette_name, 12)``. Each item's
    background is set to that colour.
    """
    self.ui.color_test_area.clear()
    # The preview size is fixed; an earlier version read it from
    # self.ui.spin_number.
    n_preview = 12
    for r, g, b in sns.color_palette(self._palette_name, n_preview):
        item = QtWidgets.QListWidgetItem()
        self.ui.color_test_area.addItem(item)
        # Round rather than truncate: a scaled float channel can land a hair
        # below the intended integer, e.g. 165.99999999999997, and int()
        # would turn that into 165.
        color = QtGui.QColor(int(round(r * 255)),
                             int(round(g * 255)),
                             int(round(b * 255)))
        item.setBackground(QtGui.QBrush(color))
figureoptions.py 文件源码 项目:coquery 作者: gkunter 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_palette(self):
    """
    Preview the selected palette in the colour test area.

    The custom palette is shown as stored. A named palette is generated
    with as many colours as the spin box currently requests. The test area
    is cleared and refilled with one ``CoqColorItem`` per colour.
    """
    if self.palette_name != "custom":
        requested = int(self.ui.spin_number.value())
        colours = sns.color_palette(self.palette_name, requested)
    else:
        colours = self.custom_palette
    area = self.ui.color_test_area
    area.clear()
    for colour in colours:
        area.addItem(CoqColorItem(colour))
matplotlib.py 文件源码 项目:physt 作者: janpipek 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def _get_cmap(kwargs):
    """Get the colour map for plots that support it.

    The ``cmap`` entry is popped from *kwargs* (falling back to
    ``default_cmap``), so the remaining keyword arguments can be passed on
    to the plotting call.

    Parameters
    ----------
    kwargs : dict
        Plotting keyword arguments. Its ``cmap`` entry may be a str, a
        ``colors.Colormap`` or a list of colors. A string that matplotlib
        does not recognise is tried as a seaborn palette name (if seaborn
        is installed).

    Returns
    -------
    colors.Colormap
        Any other object given as ``cmap`` is returned unchanged.

    Raises
    ------
    ValueError
        Re-raised from matplotlib for an unknown name when seaborn is not
        installed. Otherwise, whatever seaborn raises for an unknown
        palette name propagates.
    """
    from matplotlib.colors import ListedColormap

    cmap = kwargs.pop("cmap", default_cmap)
    if isinstance(cmap, list):
        return ListedColormap(cmap)
    if isinstance(cmap, str):
        try:
            cmap = plt.get_cmap(cmap)
        except ValueError as exc:
            # Only an unknown colormap name should trigger the fallback.
            # Catching BaseException here also swallowed KeyboardInterrupt
            # and SystemExit.
            try:
                import seaborn as sns
            except ImportError:
                raise exc
            sns_palette = sns.color_palette(cmap, n_colors=256)
            cmap = ListedColormap(sns_palette, name=cmap)
    return cmap
result_plotter.py 文件源码 项目:tensorforce-benchmark 作者: reinforceio 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def make_palette(self):
    """Lazily build a "husl" palette with one colour per benchmark."""
    if self.palette:
        return
    self.palette = sns.color_palette("husl", len(self.benchmarks))
information_value.py 文件源码 项目:score_card_base_python 作者: zzstrwolf 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def plot_br_chart(self, column):
    """
    Show a bar chart of bin counts with the bad rate overlaid for *column*.

    Bins come from ``self.woe_dicts[column]``. For each value, ``value[1]``
    is plotted as the count and ``value[2]`` as the bad rate. Bins with
    string keys are ordered with ``self.sort_dict``, and all others by key.
    Counts are drawn as bars on the left axis and the bad rate as a black
    line on a twin right axis. The figure is displayed with ``plt.show()``.
    """
    woe_dict = self.woe_dicts[column]
    # dict.items() is not indexable on Python 3; look at the first key via
    # an iterator instead (None for an empty dict -> sort by key).
    first_key = next(iter(woe_dict), None)
    if isinstance(first_key, str):
        woe_lists = sorted(woe_dict.items(), key=self.sort_dict)
    else:
        woe_lists = sorted(woe_dict.items(), key=lambda item: item[0])
    sns.set_style(rc={"axes.facecolor": "#EAEAF2",
                      "axes.edgecolor": "#EAEAF2",
                      "axes.linewidth": 1,
                      "grid.color": "white"})
    tick_label = [item[0] for item in woe_lists]
    counts = [item[1][1] for item in woe_lists]
    br_data = [item[1][2] for item in woe_lists]
    x = list(range(len(counts)))
    _, ax1 = plt.subplots(figsize=(12, 8))
    # Data must be passed as keywords: current seaborn rejects positional
    # x/y arguments.
    sns.barplot(x=x, y=counts, ax=ax1,
                palette=sns.husl_palette(n_colors=20, l=.7))
    plt.xticks(x, tick_label, rotation=30, fontsize=12)
    plt.title(column, fontsize=18)
    ax1.set_ylabel('count', fontsize=15)
    ax1.tick_params('y', direction='in', length=6, width=0.5, labelsize=12)

    # Bad rate on a secondary y axis sharing the bar positions.
    ax2 = ax1.twinx()
    ax2.plot(x, br_data, color='black')
    ax2.set_ylabel('bad rate', fontsize=15)
    ax2.tick_params('y', direction='in', length=6, width=0.5, labelsize=12)

    # Pad the x range and leave 10% headroom above the tallest bar.
    plot_margin = 0.25
    x0, x1, y0, y1 = ax1.axis()
    ax1.axis((x0 - plot_margin, x1 + plot_margin, y0, y1 * 1.1))
    plt.show()
information_value.py 文件源码 项目:score_card_base_python 作者: zzstrwolf 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def save_br_chart(self, column, path):
    """
    Save a bar chart of bin counts with the bad rate overlaid for *column*.

    Uses the same layout as the interactive bad-rate chart: counts as bars
    on the left axis and the bad rate (``value[2]`` of each bin in
    ``self.woe_dicts[column]``) as a black line on a twin right axis. The
    figure is written to *path* and then closed.
    """
    woe_dict = self.woe_dicts[column]
    # dict.items() is not indexable on Python 3; look at the first key via
    # an iterator instead (None for an empty dict -> sort by key).
    first_key = next(iter(woe_dict), None)
    if isinstance(first_key, str):
        woe_lists = sorted(woe_dict.items(), key=self.sort_dict)
    else:
        woe_lists = sorted(woe_dict.items(), key=lambda item: item[0])
    tick_label = [item[0] for item in woe_lists]
    counts = [item[1][1] for item in woe_lists]
    br_data = [item[1][2] for item in woe_lists]
    x = list(range(len(counts)))
    fig, ax1 = plt.subplots(figsize=(12, 8))
    # Data must be passed as keywords: current seaborn rejects positional
    # x/y arguments.
    sns.barplot(x=x, y=counts, ax=ax1,
                palette=sns.husl_palette(n_colors=20, l=.7))
    plt.xticks(x, tick_label, rotation=30, fontsize=12)
    plt.title(column, fontsize=18)
    ax1.set_ylabel('count', fontsize=15)
    ax1.tick_params('y', labelsize=12)
    ax2 = ax1.twinx()
    ax2.plot(x, br_data, color='black')
    ax2.set_ylabel('bad rate', fontsize=15)
    ax2.tick_params('y', labelsize=12)

    # Pad the x range and leave 10% headroom above the tallest bar.
    plot_margin = 0.25
    x0, x1, y0, y1 = ax1.axis()
    ax1.axis((x0 - plot_margin, x1 + plot_margin, y0, y1 * 1.1))
    fig.savefig(path)
    # Close the figure: saving one chart per column otherwise leaves every
    # figure open in pyplot's figure manager.
    plt.close(fig)
fmriplots.py 文件源码 项目:mriqc 作者: poldracklab 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def plot(self):
    """
    Lay out and draw the full fMRI summary figure.

    The GridSpec has one row per spike series (drawn with ``spikesplot``)
    and one row per confound (drawn with ``confoundplot``, each in its own
    "husl" colour). These are followed by a final row, 3.5 times taller,
    drawn with ``fmricarpetplot``. The GridSpec is stored on ``self.grid``.
    """
    n_confounds = len(self.confounds)
    n_rows = 1 + n_confounds + len(self.spikes)
    row_heights = [1] * (n_rows - 1) + [3.5]
    grid = mgs.GridSpec(n_rows, 1, wspace=0.0, hspace=0.2,
                        height_ratios=row_heights)

    row = 0
    for tsz, name, iszs in self.spikes:
        spikesplot(tsz, title=name, outer_gs=grid[row], tr=self.tr,
                   zscored=iszs)
        row += 1

    if self.confounds:
        colours = color_palette("husl", n_confounds)
        for idx, (tseries, kwargs) in enumerate(self.confounds):
            confoundplot(tseries, grid[row], tr=self.tr, color=colours[idx],
                         **kwargs)
            row += 1

    fmricarpetplot(self.func_data, self.seg_data, grid[-1], tr=self.tr)
    self.grid = grid
plots.py 文件源码 项目:Comparative-Annotation-Toolkit 作者: ComparativeGenomicsToolkit 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def generic_unstacked_barplot(df, pdf, title_string, legend_labels, ylabel, names, box_label,
                              bbox_to_anchor=(1.12, 0.7)):
    """
    Draw side-by-side bars: one group per column of *df*, one bar per row.

    Row ``i`` becomes a bar series of width ``bar_width / len(df)``, shifted
    right by ``i`` such widths and drawn in the ``i``-th colour of the
    current seaborn palette. The figure is finished by
    ``_generic_histogram``.
    """
    fig, ax = plt.subplots()
    n_series = len(df)
    # max(..., 1) avoids ZeroDivisionError for an empty frame.
    shorter_bar_width = bar_width / max(n_series, 1)
    # Request exactly one colour per row. Indexing sns.color_palette()
    # directly raises IndexError once rows outnumber the palette's colours,
    # whereas n_colors cycles through them. The first colours are unchanged.
    palette = sns.color_palette(n_colors=n_series)
    bars = []
    for i, (_, row) in enumerate(df.iterrows()):
        bars.append(ax.bar(np.arange(len(df.columns)) + shorter_bar_width * i, row, shorter_bar_width,
                           color=palette[i], linewidth=0.0))
    _generic_histogram(bars, legend_labels, title_string, pdf, ax, fig, ylabel, names, box_label, bbox_to_anchor)
plots.py 文件源码 项目:Comparative-Annotation-Toolkit 作者: ComparativeGenomicsToolkit 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def generic_stacked_barplot(df, pdf, title_string, legend_labels, ylabel, names, box_label, bbox_to_anchor=(1.12, 0.7)):
    """
    Draw stacked bars: one bar per column of *df*, one segment per row.

    Each row is stacked on top of the running total of the rows before it
    and coloured from ``choose_palette(legend_labels)``. The figure is
    finished by ``_generic_histogram``.
    """
    fig, ax = plt.subplots()
    positions = np.arange(len(df.columns))
    running_total = np.zeros(len(df.columns))
    palette = choose_palette(legend_labels)
    bars = []
    for idx, (_, row) in enumerate(df.iterrows()):
        segment = ax.bar(positions, row, bar_width, bottom=running_total,
                         color=palette[idx], linewidth=0.0)
        bars.append(segment)
        running_total += row
    _generic_histogram(bars, legend_labels, title_string, pdf, ax, fig, ylabel, names, box_label, bbox_to_anchor)


###
# Shared functions
###
plots.py 文件源码 项目:Comparative-Annotation-Toolkit 作者: ComparativeGenomicsToolkit 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def choose_palette(ordered_genomes):
    """
    Choose a palette so that each genome gets its own colour.

    Up to six genomes use the current default seaborn palette. Larger sets
    use "Set2" with one entry per genome.
    """
    # NOTE(review): "Set2" is believed to have only 8 distinct colours, so
    # larger sets may get repeated colours -- confirm if that matters.
    n_genomes = len(ordered_genomes)
    if n_genomes > 6:
        return sns.color_palette("Set2", n_genomes)
    return sns.color_palette()
deprecated_seasonal_flu.py 文件源码 项目:augur 作者: nextstrain 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def plot_frequencies(flu, gene, mutation=None, plot_regions=None, all_muts=False, ax=None, **kwargs):
    """
    Plot per-region frequency trajectories (or sample counts) for *gene*.

    Parameters
    ----------
    flu : object
        Provides ``pivots``, ``mutation_frequencies``,
        ``mutation_frequency_confidence`` and ``mutation_frequency_counts``.
    gene :
        Combined with a region name to index the frequency dictionaries.
    mutation : int, other, or None
        An int selects every mutation at that position whose first global
        frequency value is below 0.5 (all of them if *all_muts*). Any other
        non-None value is plotted as a single mutation. With None, the
        per-region sample counts are plotted instead.
    plot_regions : iterable, optional
        Regions to plot. Defaults to the module-level ``regions``.
    all_muts : bool
        Keep mutations regardless of their initial frequency.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure is created when omitted.
    **kwargs
        Forwarded to ``plot_trace``.
    """
    import seaborn as sns
    sns.set_style('whitegrid')
    cols = sns.color_palette()
    linestyles = ['-', '--', '-.', ':']
    if plot_regions is None:
        plot_regions = regions
    pivots = flu.pivots
    if ax is None:
        plt.figure()
        ax = plt.subplot(111)

    if isinstance(mutation, int):
        # .items() works on Python 2 and 3; .iteritems() is Python 2 only.
        global_freqs = flu.mutation_frequencies[('global', gene)]
        mutations = [x for x, freq in global_freqs.items()
                     if x[0] == mutation and (freq[0] < 0.5 or all_muts)]
    elif mutation is not None:
        mutations = [mutation]
    else:
        mutations = None

    if mutations is None:
        for ri, region in enumerate(plot_regions):
            count = flu.mutation_frequency_counts[region]
            # Draw on *ax*: plt.plot targets the current axes, which need
            # not be the axes supplied by the caller.
            ax.plot(pivots, count, c=cols[ri % len(cols)], label=region)
    else:
        print("plotting mutations", mutations)
        for ri, region in enumerate(plot_regions):
            for mi, mut in enumerate(mutations):
                if mut in flu.mutation_frequencies[(region, gene)]:
                    freq = flu.mutation_frequencies[(region, gene)][mut]
                    err = flu.mutation_frequency_confidence[(region, gene)][mut]
                    c = cols[ri % len(cols)]
                    label_str = str(mut[0]+1)+mut[1]+', '+region
                    plot_trace(ax, pivots, freq, err, c=c,
                               ls=linestyles[mi % len(linestyles)], label=label_str, **kwargs)
                else:
                    print(mut, 'not found in region', region)
    ax.ticklabel_format(useOffset=False)
    ax.legend(loc=2)
deprecated_seasonal_flu.py 文件源码 项目:augur 作者: nextstrain 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def plot_sequence_count(flu, fname=None, fs=12):
    """
    Plot sample counts per time bin as bars broken down by region.

    The global count is drawn in grey (labelled "Other"). The counts of
    every non-global region in ``region_groups`` are stacked on top of each
    other in front of it. The first three pivots are skipped. If *fname* is
    given, the figure is saved there.
    """
    import seaborn as sns
    date_bins = pivots_to_dates(flu.pivots)
    sns.set_style('ticks')
    region_label = {'global': 'Global', 'NA': 'N America', 'AS': 'Asia', 'EU': 'Europe', 'OC': 'Oceania'}
    regions_abbr = ['global', 'NA', 'AS', 'EU', 'OC']
    palette = sns.color_palette(n_colors=len(regions_abbr))
    region_colors = dict(zip(regions_abbr, palette))
    # NOTE(review): the loop below walks the module-level region_groups,
    # but labels and colours exist only for the five regions above, so any
    # other region raises KeyError -- confirm region_groups is a subset.
    fig, ax = plt.subplots(figsize=(8, 3))
    counts = flu.mutation_frequency_counts
    drop = 3
    bins = date_bins[drop:]
    stacked = np.zeros(len(flu.pivots[drop:]))
    plt.bar(bins, counts['global'][drop:], width=18,
            linewidth=0, label="Other", color="#bbbbbb", clip_on=False)
    for region in region_groups:
        if region == 'global':
            continue
        region_counts = counts[region][drop:]
        plt.bar(bins, region_counts, bottom=stacked, width=18, linewidth=0,
                label=region_label[region], color=region_colors[region], clip_on=False)
        stacked += region_counts
    make_date_ticks(ax, fs=fs)
    ax.set_ylabel('Sample count')
    ax.legend(loc=3, ncol=1, bbox_to_anchor=(1.02, 0.53))
    plt.subplots_adjust(left=0.1, right=0.82, top=0.94, bottom=0.22)
    sns.despine()
    if fname is not None:
        plt.savefig(fname)
deprecated_flu_prediction.py 文件源码 项目:augur 作者: nextstrain 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def plot_prediction(self):
    '''
    Plot global, predicted and training-window frequencies side by side.

    Left panel: vertical lines mark ``self.t_cut`` (dashed) and the end of
    the current prediction interval (solid). For every node whose
    prediction exceeds 0.02 at some pivot after the training start, three
    curves share one colour per clade: the prediction beyond the training
    window (dashed), the global frequency (solid) and the training-window
    frequency (dash-dot).
    Right panel: the prediction error, set to zero at pivots up to the last
    training pivot.
    '''
    from matplotlib import pyplot as plt
    import seaborn as sns
    fig, (left, right) = plt.subplots(1, 2, figsize=(12, 6))

    left.plot(self.t_cut * np.ones(2), [0, 1], lw=3, alpha=0.3, c='k', ls='--')
    left.plot(self.current_prediction_interval[1] * np.ones(2), [0, 1], lw=3, alpha=0.3, c='k')

    interval = self.current_prediction_interval
    train_pivots = self.train_frequencies[interval][0]
    train_freqs = self.train_frequencies[interval][1]
    cols = sns.color_palette()
    pivots = self.global_pivots
    future_pivots = pivots > train_pivots[-1]
    after_train_start = pivots > train_pivots[0]
    for node in self.predictions:
        prediction = self.predictions[node]
        if np.max(prediction[after_train_start]) > 0.02:
            colour = cols[node.clade % 6]
            left.plot(pivots[future_pivots], prediction[future_pivots], ls='--', c=colour)
            left.plot(pivots, self.global_freqs[node.clade], ls='-', c=colour)
            left.plot(train_pivots, train_freqs[node.clade], ls='-.', c=colour)

    left.set_xlim(train_pivots[0] - 2, train_pivots[-1] + 2)
    dev = self.prediction_error()
    dev[~future_pivots] = 0.0
    right.plot(pivots, dev)
    right.set_xlim(train_pivots[0], train_pivots[-1] + 2)
    right.set_ylim(0, 3)
plotter.py 文件源码 项目:harpreif 作者: harpribot 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def scatter(x, colors):
    """
    Scatter-plot 2-D points coloured by label and annotate labels 0-9.

    Parameters
    ----------
    x : array
        Point coordinates, read as ``x[:, 0]`` and ``x[:, 1]``.
    colors : array
        One integer-valued label per point. It indexes a 258-colour "hls"
        palette, and labels 0-9 are annotated at their median position.

    Returns
    -------
    tuple
        ``(figure, axes, scatter collection, list of label Text objects)``.
    """
    # We choose a color palette with seaborn.
    palette = np.array(sea.color_palette("hls", 258))

    # We create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # np.int was only an alias of the builtin int and was removed in
    # NumPy 1.24; int selects the same dtype.
    sc = ax.scatter(x[:, 0], x[:, 1], lw=0, s=40,
                    c=palette[colors.astype(int)])
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    ax.axis('tight')

    # We add the labels for each digit.
    txts = []
    for i in range(10):
        # Position of each label.
        xtext, ytext = np.median(x[colors == i, :], axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        txt.set_path_effects([
            patheffects.Stroke(linewidth=5, foreground="w"),
            patheffects.Normal()])
        txts.append(txt)

    plt.show()
    return f, ax, sc, txts
summary_stats.py 文件源码 项目:fake_news 作者: bmassman 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def word_count_by_label(articles: pd.DataFrame):
    """Show graph of word counts by article label.

    Draws kernel density estimates of ``word_count`` for articles labelled
    0 ("True News") and 1 ("Fake News") on the same axes, adds a legend and
    shows the figure.
    """
    # Recent seaborn releases no longer expose pyplot as ``sns.plt``; use
    # matplotlib.pyplot directly, which is the module that alias pointed to.
    import matplotlib.pyplot as plt

    palette = sns.color_palette(palette='hls', n_colors=2)
    true_news_wc = articles[articles['labels'] == 0]['word_count']
    fake_news_wc = articles[articles['labels'] == 1]['word_count']
    # NOTE(review): kdeplot's ``bw`` argument is deprecated in newer seaborn
    # in favour of ``bw_method``/``bw_adjust``.
    sns.kdeplot(true_news_wc, bw=3, color=palette[0], label='True News')
    sns.kdeplot(fake_news_wc, bw=3, color=palette[1], label='Fake News')
    plt.legend()
    plt.show()
vis.py 文件源码 项目:facenet_pytorch 作者: liorshk 项目源码 文件源码 阅读 40 收藏 0 点赞 0 评论 0
def visual_feature_space(features, labels, num_classes, name_dict):
    """
    Scatter-plot 2-D features coloured by class, then save and show it.

    Parameters
    ----------
    features : array
        Read as ``features[:, 0]`` and ``features[:, 1]``.
    labels : array
        One integer-valued class per row of *features*.
    num_classes : int
        Number of "hls" palette colours and of class annotations.
    name_dict : mapping
        Maps class index ``i`` to the text drawn at that class's median
        position.

    Returns
    -------
    tuple
        ``(figure, axes, scatter collection, list of annotation Texts)``.
        The figure is also written to ``center_loss.png``.
    """
    title_font = {'fontname': 'Arial', 'size': '20', 'color': 'black', 'weight': 'normal',
                  'verticalalignment': 'bottom'}  # Bottom vertical alignment for more space
    axis_font = {'fontname': 'Arial', 'size': '20'}

    palette = np.array(sns.color_palette("hls", num_classes))

    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
    sc = ax.scatter(features[:, 0], features[:, 1], lw=0, s=40,
                    c=palette[labels.astype(int)])

    # Annotate each class at the median position of its points.
    txts = []
    for i in range(num_classes):
        xtext, ytext = np.median(features[labels == i, :], axis=0)
        txt = ax.text(xtext, ytext, name_dict[i])
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])
        txts.append(txt)
    ax.set_xlabel('Activation of the 1st neuron', **axis_font)
    ax.set_ylabel('Activation of the 2nd neuron', **axis_font)
    ax.set_title('softmax_loss + center_loss', **title_font)
    # Axes.set_axis_bgcolor no longer exists in matplotlib; set_facecolor
    # is its replacement.
    ax.set_facecolor('grey')
    f.savefig('center_loss.png')
    plt.show()
    return f, ax, sc, txts


问题


面经


文章

微信
公众号

扫码关注公众号