def set_defaults(self):
    """
    Set the plot defaults.

    The colour palette is chosen from the number of levels of the grouping
    factor that is shown in the legend: the second factor when there are
    two grouping factors, otherwise the first. After the base class
    defaults are applied, the x axis is labelled "Percentage" or
    "Frequency", and the y axis and legend get the translated grouping
    headers.
    """
    n_levels = len(self._levels[1 if len(self._groupby) == 2 else 0])
    # Choose the "Paired" palette if the number of grouping factor levels
    # is even and below 13 (10 is included so the code matches this rule).
    if n_levels in (2, 4, 6, 8, 10, 12):
        self.options["color_palette"] = "Paired"
    elif len(self._groupby) == 2:
        # "Set3", a qualitative palette, when there are two grouping factors.
        self.options["color_palette"] = "Set3"
    else:
        # "RdPu" (red to purple) for a single grouping factor.
        self.options["color_palette"] = "RdPu"
    super(Visualizer, self).set_defaults()
    self.options["label_x_axis"] = (
        "Percentage" if self.percentage else "Frequency")
    session = options.cfg.main_window.Session
    if len(self._groupby) == 2:
        self.options["label_y_axis"] = session.translate_header(self._groupby[0])
        self.options["label_legend"] = session.translate_header(self._groupby[1])
    else:
        self.options["label_legend"] = session.translate_header(self._groupby[0])
        # With a single factor in percentage mode the y axis stays unlabelled.
        if self.percentage:
            self.options["label_y_axis"] = ""
        else:
            self.options["label_y_axis"] = session.translate_header(self._groupby[0])
Example source code for Python color_palette()
def set_defaults(self):
    """
    Set the plot defaults.

    Uses the "Paired" palette with one colour per level of the first
    grouping factor, labels the x axis "Corpus position", and labels the
    y axis with the first grouping factor only when it has at least two
    levels.
    """
    self.options["color_palette"] = "Paired"
    # The y-label check below explicitly allows an empty level list, so
    # guard the colour count too instead of raising IndexError; fall back
    # to a single colour.
    self.options["color_number"] = (
        len(self._levels[0]) if self._levels else 1)
    super(Visualizer, self).set_defaults()
    self.options["label_x_axis"] = "Corpus position"
    if not self._levels or len(self._levels[0]) < 2:
        self.options["label_y_axis"] = ""
    else:
        self.options["label_y_axis"] = self._groupby[0]
def set_defaults(self):
    """
    Fill in colour and legend options that have not been set yet.

    Options that already hold a truthy value are kept. One exception: when
    no palette is set and there are no levels, ``color_number`` is forced
    to 1. ``figure_font`` is always set to the application font, and
    ``color_palette_values`` are computed via ``set_palette_values`` when
    missing.
    """
    # NOTE: a branch for numerical axes used to sit here behind
    # ``if self.numerical_axes and False:``, so it could never execute; it
    # has been removed. Reinstate it if numerical axes need other defaults.
    has_levels = len(self._levels) > 0
    if not self.options.get("color_number"):
        # One colour per level of the last grouping factor; a single colour
        # when there are no levels (``self._levels[-1]`` would raise).
        self.options["color_number"] = (
            len(self._levels[-1]) if has_levels else 1)
    if not self.options.get("label_legend_columns"):
        self.options["label_legend_columns"] = 1
    if not self.options.get("color_palette"):
        if not has_levels:
            self.options["color_palette"] = "Paired"
            self.options["color_number"] = 1
        elif (len(self._levels[-1]) in (2, 4, 6)
              or len(self._groupby) == 2):
            self.options["color_palette"] = "Paired"
        else:
            self.options["color_palette"] = "RdPu"
    self.options["figure_font"] = (
        QtWidgets.QApplication.instance().font())
    if not self.options.get("color_palette_values"):
        self.set_palette_values(self.options["color_number"])
def set_palette_values(self, n=None):
    """
    Recompute ``color_palette_values`` for ``n`` colours.

    A falsy ``n`` falls back to the stored ``color_number``; otherwise ``n``
    becomes the new ``color_number``. A "custom" palette keeps its existing
    values untouched.
    """
    opts = self.options
    if n:
        opts["color_number"] = n
    else:
        n = opts["color_number"]
    if opts["color_palette"] == "custom":
        return
    opts["color_palette_values"] = sns.color_palette(
        opts["color_palette"], n)
def show_palette(self):
    """Fill the colour test area with a preview swatch per palette colour."""
    area = self.ui.color_test_area
    area.clear()
    # The preview size is fixed at 12 swatches; it is not read from the
    # number spinner.
    for red, green, blue in sns.color_palette(self._palette_name, 12):
        swatch = QtWidgets.QListWidgetItem()
        area.addItem(swatch)
        rgb = [int(channel * 255) for channel in (red, green, blue)]
        swatch.setBackground(QtGui.QBrush(QtGui.QColor(*rgb)))
def test_palette(self):
    """
    Show the selected palette as colour items in the test area.

    A "custom" palette uses ``self.custom_palette``; any other palette is
    generated with as many colours as the number spinner shows.
    """
    if self.palette_name != "custom":
        colors = sns.color_palette(
            self.palette_name, int(self.ui.spin_number.value()))
    else:
        colors = self.custom_palette
    area = self.ui.color_test_area
    area.clear()
    for entry in colors:
        area.addItem(CoqColorItem(entry))
def _get_cmap(kwargs):
    """Get the colour map for plots that support it.

    The ``"cmap"`` entry is popped from ``kwargs`` (falling back to the
    module-level ``default_cmap``), so the caller's dict no longer holds it.

    Parameters
    ----------
    kwargs : dict
        Plot keyword arguments. Its ``"cmap"`` entry may be a matplotlib
        colormap name, a ``colors.Colormap`` instance, a list of colours, or
        the name of a seaborn palette (if seaborn is installed).

    Returns
    -------
    colors.Colormap
        A list is wrapped in a ``ListedColormap``; a string is resolved as a
        matplotlib colormap, or else as a 256-colour seaborn palette; any
        other value is returned unchanged.

    Raises
    ------
    ValueError
        If a string names neither a matplotlib colormap nor (when seaborn is
        available) a seaborn palette.
    """
    from matplotlib.colors import ListedColormap
    cmap = kwargs.pop("cmap", default_cmap)
    if isinstance(cmap, list):
        return ListedColormap(cmap)
    if isinstance(cmap, str):
        try:
            cmap = plt.get_cmap(cmap)
        except ValueError as exc:
            # Unknown to matplotlib: try the name as a seaborn palette.
            # Only ValueError is caught so KeyboardInterrupt/SystemExit and
            # unrelated errors are no longer swallowed.
            try:
                import seaborn as sns
            except ImportError:
                raise exc
            sns_palette = sns.color_palette(cmap, n_colors=256)
            cmap = ListedColormap(sns_palette, name=cmap)
    return cmap
def make_palette(self):
    """Create a "husl" palette with one colour per benchmark, unless one is already set."""
    if self.palette:
        return
    self.palette = sns.color_palette("husl", len(self.benchmarks))
def _sorted_woe_items(self, column):
    """Return the (key, stats) pairs of ``self.woe_dicts[column]`` in plot order.

    String keys are ordered with ``self.sort_dict``; any other key type is
    ordered by the key itself.
    """
    woe_dict = self.woe_dicts[column]
    # Dict views are not indexable on Python 3, so peek at the first key
    # through an iterator instead of ``items()[0][0]``.
    first_key = next(iter(woe_dict), None)
    if isinstance(first_key, str):
        return sorted(woe_dict.items(), key=self.sort_dict)
    return sorted(woe_dict.items(), key=lambda item: item[0])


def _draw_br_chart(self, column, tick_params):
    """Draw the count / bad-rate chart for ``column`` and return its figure.

    Bars show each bin's count on the left axis, and a black line shows
    each bin's bad rate on a twin right axis. ``tick_params`` are keyword
    arguments applied to the y ticks of both axes.
    """
    woe_items = self._sorted_woe_items(column)
    tick_label = [key for key, _ in woe_items]
    # Each stats value is indexed as [1] -> count, [2] -> bad rate.
    counts = [stats[1] for _, stats in woe_items]
    br_data = [stats[2] for _, stats in woe_items]
    x = list(range(len(counts)))
    fig, ax1 = plt.subplots(figsize=(12, 8))
    # Data vectors must be passed by keyword in seaborn >= 0.12.
    sns.barplot(x=x, y=counts, ax=ax1,
                palette=sns.husl_palette(n_colors=20, l=.7))
    # xticks/title target the current axes, so call them before twinx().
    plt.xticks(x, tick_label, rotation=30, fontsize=12)
    plt.title(column, fontsize=18)
    ax1.set_ylabel('count', fontsize=15)
    ax1.tick_params('y', **tick_params)
    ax2 = ax1.twinx()
    ax2.plot(x, br_data, color='black')
    ax2.set_ylabel('bad rate', fontsize=15)
    ax2.tick_params('y', **tick_params)
    # Widen the x range slightly and leave 10% headroom above the bars.
    plot_margin = 0.25
    x0, x1, y0, y1 = ax1.axis()
    ax1.axis((x0 - plot_margin, x1 + plot_margin, y0, y1 * 1.1))
    return fig


def plot_br_chart(self, column):
    """Display the count / bad-rate chart for ``column`` interactively."""
    sns.set_style(rc={"axes.facecolor": "#EAEAF2",
                      "axes.edgecolor": "#EAEAF2",
                      "axes.linewidth": 1,
                      "grid.color": "white"})
    self._draw_br_chart(column, dict(direction='in', length=6, width=0.5,
                                     labelsize=12))
    plt.show()


def save_br_chart(self, column, path):
    """Save the count / bad-rate chart for ``column`` to ``path``.

    The figure is closed after saving so repeated calls do not accumulate
    open figures.
    """
    fig = self._draw_br_chart(column, dict(labelsize=12))
    fig.savefig(path)
    plt.close(fig)
def plot(self):
    """
    Lay out the spike, confound and carpet panels on one GridSpec.

    Each spike series and each confound series gets a row of relative
    height 1. The last row (relative height 3.5) holds the carpet plot.
    Confounds are coloured from a "husl" palette. The grid is stored on
    ``self.grid``.
    """
    n_confounds = len(self.confounds)
    n_rows = 1 + n_confounds + len(self.spikes)
    grid = mgs.GridSpec(n_rows, 1, wspace=0.0, hspace=0.2,
                        height_ratios=[1] * (n_rows - 1) + [3.5])
    row = 0
    for tsz, name, iszs in self.spikes:
        spikesplot(tsz, title=name, outer_gs=grid[row], tr=self.tr,
                   zscored=iszs)
        row += 1
    if self.confounds:
        colors = color_palette("husl", n_confounds)
        for color, (tseries, kwargs) in zip(colors, self.confounds):
            confoundplot(tseries, grid[row], tr=self.tr, color=color,
                         **kwargs)
            row += 1
    fmricarpetplot(self.func_data, self.seg_data, grid[-1], tr=self.tr)
    self.grid = grid
    # spikesplot_cb([0.7, 0.78, 0.2, 0.008])
plots.py source file
Project: Comparative-Annotation-Toolkit
Author: ComparativeGenomicsToolkit
Project source
File source
Views: 29
Favorites: 0
Likes: 0
Comments: 0
def generic_unstacked_barplot(df, pdf, title_string, legend_labels, ylabel, names, box_label,
                              bbox_to_anchor=(1.12, 0.7)):
    """
    Draw grouped (side-by-side) bars: one group per column of ``df``, one bar per row.

    The module-level ``bar_width`` is divided evenly among the rows. Legend,
    labels and output are delegated to ``_generic_histogram``.
    """
    fig, ax = plt.subplots()
    n_rows = len(df)
    shorter_bar_width = bar_width / n_rows
    # One colour per row, built once. Asking seaborn for n_rows colours
    # cycles the default palette when there are more rows than colours,
    # instead of the IndexError from ``sns.color_palette()[i]``.
    palette = sns.color_palette(n_colors=n_rows)
    positions = np.arange(len(df.columns))
    bars = []
    for i, (_, d) in enumerate(df.iterrows()):
        bars.append(ax.bar(positions + shorter_bar_width * i, d, shorter_bar_width,
                           color=palette[i], linewidth=0.0))
    _generic_histogram(bars, legend_labels, title_string, pdf, ax, fig, ylabel, names, box_label, bbox_to_anchor)
plots.py source file
Project: Comparative-Annotation-Toolkit
Author: ComparativeGenomicsToolkit
Project source
File source
Views: 30
Favorites: 0
Likes: 0
Comments: 0
def generic_stacked_barplot(df, pdf, title_string, legend_labels, ylabel, names, box_label, bbox_to_anchor=(1.12, 0.7)):
    """
    Draw stacked bars: one stack per column of ``df``, one segment per row.

    Each row is drawn on top of the running total of the rows before it,
    using the module-level ``bar_width`` and colours from
    ``choose_palette(legend_labels)``. Legend, labels and output are
    delegated to ``_generic_histogram``.
    """
    fig, ax = plt.subplots()
    colors = choose_palette(legend_labels)
    positions = np.arange(len(df.columns))
    bottoms = np.zeros(len(df.columns))
    bars = []
    for row_idx, (_, row) in enumerate(df.iterrows()):
        segment = ax.bar(positions, row, bar_width, bottom=bottoms,
                         color=colors[row_idx], linewidth=0.0)
        bars.append(segment)
        bottoms += row
    _generic_histogram(bars, legend_labels, title_string, pdf, ax, fig, ylabel, names, box_label, bbox_to_anchor)
###
# Shared functions
###
plots.py source file
Project: Comparative-Annotation-Toolkit
Author: ComparativeGenomicsToolkit
Project source
File source
Views: 27
Favorites: 0
Likes: 0
Comments: 0
def choose_palette(ordered_genomes):
    """
    Choose a palette for cases where each genome gets its own colour.

    Up to six genomes use seaborn's current default palette; larger sets
    use "Set2" sized to the number of genomes.
    NOTE(review): "Set2" is believed to have 8 distinct colours, so with
    more than 8 genomes colours would repeat — confirm if that matters.
    """
    n_genomes = len(ordered_genomes)
    if n_genomes > 6:
        return sns.color_palette("Set2", n_genomes)
    return sns.color_palette()
def plot_frequencies(flu, gene, mutation=None, plot_regions=None, all_muts=False, ax=None, **kwargs):
    """
    Plot per-region sample counts, or frequency trajectories of mutations.

    Parameters
    ----------
    flu : object exposing ``pivots``, ``mutation_frequencies``,
        ``mutation_frequency_confidence`` and ``mutation_frequency_counts``,
        the latter three looked up by region (counts) or ``(region, gene)``.
    gene : gene name used in the ``(region, gene)`` lookup keys.
    mutation : int, mutation key, or None
        An int selects every mutation in the global data whose first
        element equals it and whose first frequency value is below 0.5
        (all of them with ``all_muts``). Any other non-None value is
        plotted as a single mutation. ``None`` plots per-region counts.
        Presumably mutation keys are ``(0-based position, allele)``, given
        how labels are built — confirm.
    plot_regions : regions to plot; defaults to the module-level ``regions``.
    all_muts : keep int-selected mutations regardless of frequency.
    ax : axes to draw on; a new figure and subplot are created when omitted.
    **kwargs : forwarded to ``plot_trace``.
    """
    import seaborn as sns
    sns.set_style('whitegrid')
    cols = sns.color_palette()
    linestyles = ['-', '--', '-.', ':']
    if plot_regions is None:
        plot_regions = regions
    pivots = flu.pivots
    if ax is None:
        plt.figure()
        ax = plt.subplot(111)

    if isinstance(mutation, int):
        # ``.items()`` works on dicts and pandas Series alike; ``iteritems``
        # does not exist on Python 3 dicts (and was removed in pandas 2.0).
        global_freqs = flu.mutation_frequencies[('global', gene)]
        mutations = [x for x, freq in global_freqs.items()
                     if x[0] == mutation and (freq[0] < 0.5 or all_muts)]
    elif mutation is not None:
        mutations = [mutation]
    else:
        mutations = None

    if mutations is None:
        for ri, region in enumerate(plot_regions):
            count = flu.mutation_frequency_counts[region]
            # Draw on ``ax`` rather than pyplot's current axes so that a
            # caller-supplied axes is honoured.
            ax.plot(pivots, count, c=cols[ri % len(cols)], label=region)
    else:
        print("plotting mutations", mutations)
        for ri, region in enumerate(plot_regions):
            c = cols[ri % len(cols)]
            for mi, mut in enumerate(mutations):
                if mut in flu.mutation_frequencies[(region, gene)]:
                    freq = flu.mutation_frequencies[(region, gene)][mut]
                    err = flu.mutation_frequency_confidence[(region, gene)][mut]
                    label_str = str(mut[0] + 1) + mut[1] + ', ' + region
                    plot_trace(ax, pivots, freq, err, c=c,
                               ls=linestyles[mi % len(linestyles)],
                               label=label_str, **kwargs)
                else:
                    print(mut, 'not found in region', region)
    ax.ticklabel_format(useOffset=False)
    ax.legend(loc=2)
def plot_sequence_count(flu, fname=None, fs=12):
    """
    Draw a stacked bar chart of sample counts per region over time.

    The first three pivots are skipped. The 'global' counts are drawn as a
    grey "Other" background bar. Every other region in the module-level
    ``region_groups`` is stacked on top of it in its own colour. The figure
    is saved to ``fname`` when one is given.
    """
    import seaborn as sns
    dates = pivots_to_dates(flu.pivots)
    sns.set_style('ticks')
    pretty_name = {'global': 'Global', 'NA': 'N America', 'AS': 'Asia',
                   'EU': 'Europe', 'OC': 'Oceania'}
    abbreviations = ['global', 'NA', 'AS', 'EU', 'OC']
    colour_of = dict(zip(abbreviations,
                         sns.color_palette(n_colors=len(abbreviations))))
    fig, ax = plt.subplots(figsize=(8, 3))
    counts = flu.mutation_frequency_counts
    skip = 3
    shown_dates = dates[skip:]
    stack_top = np.zeros(len(flu.pivots[skip:]))
    bar_style = dict(width=18, linewidth=0, clip_on=False)
    plt.bar(shown_dates, counts['global'][skip:], label="Other",
            color="#bbbbbb", **bar_style)
    for region in region_groups:
        if region == 'global':
            continue
        plt.bar(shown_dates, counts[region][skip:], bottom=stack_top,
                label=pretty_name[region], color=colour_of[region],
                **bar_style)
        stack_top += counts[region][skip:]
    make_date_ticks(ax, fs=fs)
    ax.set_ylabel('Sample count')
    ax.legend(loc=3, ncol=1, bbox_to_anchor=(1.02, 0.53))
    plt.subplots_adjust(left=0.1, right=0.82, top=0.94, bottom=0.22)
    sns.despine()
    if fname is not None:
        plt.savefig(fname)
def plot_prediction(self):
    '''
    Plot the global frequencies, the predicted frequencies, and the
    frequencies in the short interval used for learning.

    Left panel: vertical markers at ``self.t_cut`` (dashed) and at the end
    of the current prediction interval. For every node whose prediction
    exceeds 0.02 after the start of the training pivots, three lines share
    one colour per clade: the prediction (dashed, future pivots only), the
    global frequency (solid), and the training frequency (dash-dot).
    Right panel: the prediction error, zeroed outside the future pivots.
    NOTE(review): ``dev`` is modified in place; if ``prediction_error``
    returns a cached array, that cache is mutated — confirm.
    '''
    from matplotlib import pyplot as plt
    import seaborn as sns
    _, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].plot(self.t_cut * np.ones(2), [0, 1], lw=3, alpha=0.3, c='k', ls='--')
    axs[0].plot(self.current_prediction_interval[1] * np.ones(2), [0, 1],
                lw=3, alpha=0.3, c='k')
    train_pivots = self.train_frequencies[self.current_prediction_interval][0]
    train_freqs = self.train_frequencies[self.current_prediction_interval][1]
    cols = sns.color_palette()
    future_pivots = self.global_pivots > train_pivots[-1]
    # Loop-invariant mask: pivots after the start of the training interval.
    after_train_start = self.global_pivots > train_pivots[0]
    for node in self.predictions:
        if np.max(self.predictions[node][after_train_start]) > 0.02:
            # Cycle through the whole palette instead of a hard-coded 6
            # colours (seaborn's default palette may have more than 6).
            c = cols[node.clade % len(cols)]
            axs[0].plot(self.global_pivots[future_pivots],
                        self.predictions[node][future_pivots], ls='--', c=c)
            axs[0].plot(self.global_pivots, self.global_freqs[node.clade], ls='-', c=c)
            axs[0].plot(train_pivots, train_freqs[node.clade], ls='-.', c=c)
    axs[0].set_xlim(train_pivots[0] - 2, train_pivots[-1] + 2)
    dev = self.prediction_error()
    dev[~future_pivots] = 0.0
    axs[1].plot(self.global_pivots, dev)
    axs[1].set_xlim(train_pivots[0], train_pivots[-1] + 2)
    axs[1].set_ylim(0, 3)
def scatter(x, colors):
    """
    Scatter 2-D points coloured by integer class and label classes 0-9.

    Parameters
    ----------
    x : array whose rows are points with exactly two columns (the median
        position of each class is unpacked into two coordinates).
    colors : array of class indices, one per row of ``x``; used to index a
        258-colour "hls" palette.

    Returns
    -------
    (figure, axes, scatter collection, list of text artists)
    """
    # We choose a color palette with seaborn.
    palette = np.array(sea.color_palette("hls", 258))
    # We create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # ``np.int`` was only an alias for the builtin ``int`` and was removed
    # in NumPy 1.24; the builtin gives the same dtype.
    sc = ax.scatter(x[:, 0], x[:, 1], lw=0, s=40,
                    c=palette[colors.astype(int)])
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    ax.axis('tight')
    # We add the labels for each digit at the class median position.
    txts = []
    for i in range(10):
        xtext, ytext = np.median(x[colors == i, :], axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        txt.set_path_effects([
            patheffects.Stroke(linewidth=5, foreground="w"),
            patheffects.Normal()])
        txts.append(txt)
    plt.show()
    return f, ax, sc, txts
def word_count_by_label(articles: pd.DataFrame):
    """
    Show graph of word counts by article label.

    Overlays kernel density estimates of ``articles['word_count']`` for rows
    with ``labels == 0`` ("True News") and ``labels == 1`` ("Fake News").
    NOTE(review): ``bw`` is deprecated in newer seaborn in favour of
    ``bw_method``/``bw_adjust`` — confirm against the installed version.
    """
    # ``sns.plt`` no longer exists in seaborn; use pyplot directly.
    from matplotlib import pyplot as plt
    palette = sns.color_palette(palette='hls', n_colors=2)
    true_news_wc = articles[articles['labels'] == 0]['word_count']
    fake_news_wc = articles[articles['labels'] == 1]['word_count']
    sns.kdeplot(true_news_wc, bw=3, color=palette[0], label='True News')
    sns.kdeplot(fake_news_wc, bw=3, color=palette[1], label='Fake News')
    plt.legend()
    plt.show()
def visual_feature_space(features, labels, num_classes, name_dict):
    """
    Scatter-plot 2-D features coloured by class, then save and show the figure.

    Parameters
    ----------
    features : array with exactly two columns (each class median is
        unpacked into two coordinates).
    labels : array of integer class indices, one per row of ``features``.
    num_classes : number of classes; one "hls" colour and one text label each.
    name_dict : maps a class index to the text drawn at that class's median.

    Returns
    -------
    (figure, axes, scatter collection, list of text artists)

    The figure is saved to ``center_loss.png`` in the working directory.
    """
    title_font = {'fontname': 'Arial', 'size': '20', 'color': 'black', 'weight': 'normal',
                  'verticalalignment': 'bottom'}  # Bottom vertical alignment for more space
    axis_font = {'fontname': 'Arial', 'size': '20'}
    palette = np.array(sns.color_palette("hls", num_classes))
    # We create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    # ``np.int`` was only an alias for the builtin ``int`` and was removed
    # in NumPy 1.24.
    sc = ax.scatter(features[:, 0], features[:, 1], lw=0, s=40,
                    c=palette[labels.astype(int)])
    # We add a label for each class at its median position.
    txts = []
    for i in range(num_classes):
        xtext, ytext = np.median(features[labels == i, :], axis=0)
        txt = ax.text(xtext, ytext, name_dict[i])
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])
        txts.append(txt)
    ax.set_xlabel('Activation of the 1st neuron', **axis_font)
    ax.set_ylabel('Activation of the 2nd neuron', **axis_font)
    ax.set_title('softmax_loss + center_loss', **title_font)
    # ``Axes.set_axis_bgcolor`` was removed in Matplotlib 2.2.
    ax.set_facecolor('grey')
    f.savefig('center_loss.png')
    plt.show()
    return f, ax, sc, txts