def plot(self, df, database_name, test_name, y_label):
    """Plot the 70-row rolling mean of ``df`` and save it as an SVG.

    The file is written into ``self.results_path`` and named
    ``'<database_name> - <test_name>.svg'``.  If ``y_label`` has an entry
    in ``self.ticks_formatters`` that formatter is used for the y axis.
    """
    smoothed = df.rolling(70).mean()
    # Leave 5 % headroom beyond the largest index / value.
    x_limits = (0, smoothed.index.max() * 1.05)
    y_limits = (0, smoothed.max().max() * 1.05)
    ax = smoothed.plot(
        title=test_name,
        alpha=0.8,
        xlim=x_limits,
        ylim=y_limits,
    )
    ax.set(xlabel='Amount of objects in table', ylabel=y_label)
    ax.xaxis.set_major_formatter(
        FuncFormatter(lambda v, pos: prefix_unit(v, '', -3)))
    if y_label in self.ticks_formatters:
        ax.yaxis.set_major_formatter(self.ticks_formatters[y_label])

    # The legend is anchored in figure coordinates, so it has to be passed
    # as an extra artist for the tight bounding box to include it.
    legend = ax.legend(
        loc='upper center', bbox_to_anchor=(0.5, 0.0),
        bbox_transform=plt.gcf().transFigure,
        fancybox=True, shadow=True, ncol=3)
    target = os.path.join(
        self.results_path, '%s - %s.svg' % (database_name, test_name))
    plt.savefig(target, bbox_extra_artists=(legend,), bbox_inches='tight')
python类FuncFormatter()的实例源码
def read_data_for_battery_plot(self):
    """Redraw ``self.ax`` with the charge history from ``BatteryDriver``.

    Element ``[0]`` of every history entry is used as the x value (it is
    handed to ``time.localtime`` for labelling) and element ``[1]`` as the
    y value.  Returns ``True``.
    """
    def format_date(x, pos=None):
        # Render a tick position as local "HH:MM".
        return time.strftime('%H:%M', time.localtime(x))

    history = BatteryDriver().get_history_charge()
    x = [entry[0] for entry in history]
    y = [entry[1] for entry in history]

    axes = self.ax
    axes.cla()
    axes.set_xlim(min(x), max(x))
    axes.set_ylim(-10, 110)
    axes.grid(True)
    axes.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
    self.fig.autofmt_xdate()
    axes.plot(x, y)
    self.fig.canvas.draw()
    return True
def __init__(self, infile, outfile, analysis_type, plot_format,
             plot_title, src_reverse, debug):
    """Store the constructor arguments and prepare the tick formatters.

    Four ``FuncFormatter`` objects are created: milli, kilo and mega
    divide the tick value by 1e-3, 1e+3 and 1e+6 respectively, and the
    percent formatter multiplies it by 100.  All render with ``'{0:g}'``.
    """
    self._infile = infile
    self._outfile = outfile
    self._analysis_type = analysis_type
    self._plot_format = plot_format
    self._plot_title = plot_title
    self._src_reverse = src_reverse
    self._debug = debug

    def divided_by(divisor):
        # Formatter that prints the tick value divided by ``divisor``.
        return ticker.FuncFormatter(
            lambda y, pos: '{0:g}'.format(y / divisor))

    self._format_milli = divided_by(1e-3)
    self._format_kilo = divided_by(1e+3)
    self._format_mega = divided_by(1e+6)
    self._format_percent = ticker.FuncFormatter(
        lambda y, pos: '{0:g}'.format(y * 100))
def limite_central2():
    """Histogram of running means computed from reshuffled ``evalua`` data.

    ``evalua(k, N)`` is sampled once; for each ``i`` the sample is
    reshuffled and a mean-like value of its first elements is stored.
    The values are plotted as a normalised histogram on the current
    pyplot axes, with the y ticks rendered by ``to_percent``.
    """
    N = 5000
    k = 1.99999999
    r = evalua(k, N)
    np.random.shuffle(r)
    epsilon = .1
    x1 = zeros(N)
    mu = 0
    for i in range(N):
        np.random.shuffle(r)
        # NOTE(review): this sums i elements but divides by i + 1 (so
        # x1[0] is always 0) -- confirm whether r[:i + 1] was intended.
        x1[i] = sum(r[:i]) / (i+1)
    # ``density`` replaces the ``normed`` keyword, which current matplotlib
    # releases no longer accept.
    plt.hist(x1, bins=1000, range=(mu - epsilon, mu + epsilon), density=True)
    formatter = FuncFormatter(to_percent)
    plt.gca().yaxis.set_major_formatter(formatter)
def plot(self, title='Rating Curve', log=True):
    """Show the stage/discharge scatter with the fitted curve on top.

    With ``log`` set, both axes use a log scale, the y range is fixed to
    0.01-100 and tick labels are printed as plain decimals.
    """
    def plain_decimal(value, pos):
        # Number of decimals is -log10(value) truncated to an int,
        # never less than zero.
        decimals = int(np.maximum(-np.log10(value), 0))
        return ('{{:.{:1d}f}}'.format(decimals)).format(value)

    fig = plt.figure()
    ax1 = fig.add_subplot(111, facecolor=[.95, .95, .95])
    plt.grid(True, which='both', color='w', ls='-', zorder=0)
    ax1.scatter(self.stage, self.discharge, color='k', s=10)
    ax1.set_ylabel(r'Discharge, cfs')
    ax1.set_xlabel(r'Stage, ft')
    if log:
        ax1.set_ylim(0.01, 100)
        ax1.set_yscale('log')
        ax1.set_xscale('log')
        ax1.yaxis.set_major_formatter(FuncFormatter(plain_decimal))
        ax1.xaxis.set_major_formatter(FuncFormatter(plain_decimal))
    plt.title(title)
    # Keep the grid underneath the plotted data.
    ax1.set_axisbelow(True)

    # Annotate the plot with the fitted equation.
    coefficient, exponent = self.popt[0], self.popt[1]
    ax1.text(0.05, 0.7, f'y = {coefficient:.3f}x^{exponent:.3f}',
             fontsize=15, transform=ax1.transAxes)

    # Draw the model line across the observed stage range.
    stages = np.linspace(min(self.stage), max(self.stage), 100)
    ax1.plot(stages, exp_curve(stages, self.popt[0], self.popt[1]), color='k')
    plt.show()
def create_example_s_curve_plot(self):
    """Display an example S-shaped response curve.

    The curve is ``self.logistic_function`` evaluated on 0..20000 in steps
    of 100 with ``L=10000, k=0.0007, x_0=10000``.  Both axes show integer
    tick labels with thousands separators.
    """
    fig, ax = plt.subplots(figsize=(8, 4))

    spend = np.arange(0, 20100, 100)
    response = self.logistic_function(spend, L=10000, k=0.0007, x_0=10000)
    ax.plot(spend, response, '-', label="Radio")

    ax.legend(loc='right')
    plt.xlim([0, 20000])
    plt.xlabel('Radio spend in euros')
    plt.ylabel('Additional sales')
    plt.title('Example of S-shaped response curve')
    plt.tight_layout()
    plt.grid()
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_major_formatter(
            tkr.FuncFormatter(lambda x, p: format(int(x), ',')))
    plt.show()
def _finalizeFigure(fig, ax, outFile=None, yFormat=None, sideLabel=False,
                    labelColor=None, transparent=False, openFile=False, closeFig=True):
    """Apply finishing touches to a figure and optionally save/close/open it.

    yFormat: ``','`` prints y ticks as integers with thousands separators;
        any other truthy value is used as a ``%`` format string.
    sideLabel: text drawn vertically along the right edge of the figure
        (``labelColor`` defaults to ``'lightgrey'``).
    outFile: if given, the figure is saved there.
    closeFig: close the figure through pyplot.
    openFile: pass ``outFile`` to ``systemOpenFile``.
    """
    if yFormat:
        if yFormat == ',':
            def func(x, p):
                return format(int(x), ',')
        else:
            def func(x, p):
                return yFormat % x
        ax.get_yaxis().set_major_formatter(FuncFormatter(func))

    if sideLabel:
        if not labelColor:
            labelColor = 'lightgrey'
        # Draw the label down the right-hand side of the plot.
        fig.text(1, 0.5, sideLabel, color=labelColor, weight='ultralight', fontsize=7,
                 va='center', ha='right', rotation=270)

    if outFile:
        fig.savefig(outFile, bbox_inches='tight', transparent=transparent)

    if closeFig:
        plt.close(fig)

    if openFile:
        systemOpenFile(outFile)
def plot_pore_yield_hist():
    """Plot and save a histogram of summed ``seq_length`` per channel/mux.

    ``ALL_READS`` is grouped by ``channel`` and ``mux`` and the summed
    sequence lengths are binned into 50 bins.  The image is saved into
    ``PLOTS_DIR`` with a name derived from ``SAMPLE_NAME``.
    """
    # Close any previous plots
    plt.close('all')
    num_bins = 50
    new_yield_data = ALL_READS.groupby(["channel", "mux"])['seq_length'].sum()
    fig, ax = plt.subplots(1)
    # ``density`` replaces the ``normed`` keyword, which current matplotlib
    # releases no longer accept.  Only the bin edges are needed afterwards.
    _, bins, _ = ax.hist(new_yield_data, num_bins, weights=None,
                         density=True, facecolor='blue', alpha=0.76)
    ax.xaxis.set_major_formatter(FuncFormatter(x_hist_to_human_readable))

    def y_muxhist_to_human_readable(y, position):
        # density * bin width * number of groups gives the count per bin.
        s = humanfriendly.format_size((bins[1]-bins[0])*y*new_yield_data.count(), binary=False)
        return reformat_human_friendly(s)

    ax.yaxis.set_major_formatter(FuncFormatter(y_muxhist_to_human_readable))
    # Set the titles and axis labels
    ax.set_title(f"Yield by pore {SAMPLE_NAME}")
    ax.grid(color='black', linestyle=':', linewidth=0.5)
    ax.set_xlabel("Yield in single pore")
    ax.set_ylabel("Pores per bin")
    # Ensure labels are not missed.
    fig.tight_layout()
    savefig(os.path.join(PLOTS_DIR, f"{SAMPLE_NAME.replace(' ', '_')}_hist_yield_by_pore.png"))
def create_slice(self, context):
    """Set up the axes, indicators and image of this slice from ``context``.

    :type context: dict
    """
    model = self._model
    axes = self._image.axes  # :type: matplotlib.axes.Axes

    axes.set_title(model.title, fontsize=12)
    axes.tick_params(axis='both')
    axes.set_ylabel(model.y_axis_name, fontsize=9)
    axes.set_xlabel(model.x_axis_name, fontsize=9)

    x_axis = axes.get_xaxis()
    x_axis.set_major_formatter(FuncFormatter(model.x_axis_formatter))
    x_axis.set_major_locator(AutoLocator())

    y_axis = axes.get_yaxis()
    y_axis.set_major_formatter(FuncFormatter(model.y_axis_formatter))
    y_axis.set_major_locator(AutoLocator())

    tick_labels = axes.get_xticklabels() + axes.get_yticklabels()
    for tick_label in tick_labels:
        tick_label.set_fontsize(9)

    self._reset_zoom()

    axes.add_patch(self._vertical_indicator)
    axes.add_patch(self._horizontal_indicator)
    self._update_indicators(context)

    self._image.set_cmap(cmap=context['colormap'])

    limits_by_direction = context["view_limits"]
    direction_name = self._model.index_direction['name']
    self._view_limits = limits_by_direction[direction_name]

    if model.data is not None:
        self._image.set_data(model.data)
def pct_format():
    """Return a ``FuncFormatter`` that labels chart ticks via ``to_percent``."""
    formatter = ticker.FuncFormatter(to_percent)
    return formatter
def plotYearly(dictframe, ax, uncertainty, color='#0072B2'):
    """Plot the ``yearly`` column of ``dictframe`` against ``ds``.

    A new 10x6 figure is created when ``ax`` is None.  Plotting starts at
    the smallest of the per-month last row indices.  With ``uncertainty``
    set, the band between ``yearly_lower`` and ``yearly_upper`` is shaded.
    Returns the figure that owns the axes.
    """
    if ax is None:
        figY = plt.figure(facecolor='w', figsize=(10, 6))
        ax = figY.add_subplot(111)
    else:
        figY = ax.get_figure()

    # Largest row index seen for each calendar month.
    month_of_row = dictframe.ds.dt.month
    last_index_per_month = [
        max(month_of_row[month_of_row == month].index.tolist())
        for month in range(1, 13)
    ]

    # Start from the earliest of those indices (this will almost certainly
    # leave only one year on the plot).
    start = min(last_index_per_month)
    ax.plot(dictframe['ds'][start:], dictframe['yearly'][start:], ls='-', c=color)
    if uncertainty:
        ax.fill_between(
            dictframe['ds'].values[start:],
            dictframe['yearly_lower'][start:],
            dictframe['yearly_upper'][start:],
            color=color, alpha=0.2)
    ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)

    month_locator = MonthLocator(range(1, 13), bymonthday=1, interval=2)
    ax.xaxis.set_major_formatter(FuncFormatter(
        lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
    ax.xaxis.set_major_locator(month_locator)
    ax.set_xlabel('Day of year')
    ax.set_ylabel('yearly')
    figY.tight_layout()
    return figY
def __init__(self, dark):
    """Build the figure and axes, then initialise the canvas.

    The x axis has a major tick every 600 units and labels each tick as
    whole minutes (value / 60).  With ``dark`` set the axes background is
    black.
    """
    figure = Figure(figsize=(0, 1000), dpi=75, facecolor='w', edgecolor='k')
    self.figure = figure
    axes = figure.add_axes([0.12, 0.08, 0.75, 0.90])
    self.axes = axes

    figure.patch.set_alpha(0)
    axes.margins(0, 0.05)
    axes.ticklabel_format(useOffset=False)

    x_axis = axes.xaxis
    x_axis.set_major_locator(MultipleLocatorWithMargin(600, 0, 0.03))
    x_axis.set_major_formatter(
        ticker.FuncFormatter(lambda x, pos: "{}m".format(int(x/60))))

    if dark:
        axes.patch.set_facecolor('black')

    FigureCanvas.__init__(self, self.figure)
    self.set_size_request(400, 300)
    self.lines = {}
    self.texts = {}
def plot_ohlcv(self, df):
    """Show a candlestick chart of ``df`` with index-based x tick labels.

    ``df`` must provide ``open``, ``high``, ``low`` and ``close`` columns.
    X tick positions are row numbers; each is labelled with the matching
    entry of ``df.index``, or left blank when it falls outside the frame.
    """
    fig, ax = plt.subplots()
    # Plot the candlestick
    candlestick2_ohlc(ax, df['open'], df['high'], df['low'], df['close'],
                      width=1, colorup='g', colordown='r', alpha=0.5)
    # Shift the y-limits of the candlestick plot so that there is space
    # at the bottom for the volume bar chart.
    pad = 0.25
    yl = ax.get_ylim()
    ax.set_ylim(yl[0] - (yl[1] - yl[0]) * pad, yl[1])
    # Add a second axis for the volume overlay
    ax2 = ax.twinx()
    ax2.set_position(
        matplotlib.transforms.Bbox([[0.125, 0.1], [0.9, 0.26]]))
    ax.xaxis.set_major_locator(ticker.MaxNLocator(6))

    def mydate(x, pos):
        # An explicit bounds check is required: a negative tick position
        # would otherwise index from the end of df.index and label the
        # tick with an unrelated date instead of leaving it blank.
        position = int(x)
        if 0 <= position < len(df.index):
            return df.index[position]
        return ''

    ax.xaxis.set_major_formatter(ticker.FuncFormatter(mydate))
    plt.margins(0)
    plt.show()
def plot_index_and_sentiment(tick_seq, shindex_seq, sentiment_seq, date):
    """Plot the index and sentiment series and save them as a PNG.

    The three sequences must have the same length; otherwise an error is
    printed and nothing is plotted.  ``tick_seq`` supplies the x tick
    labels.  The index line is red when the last value is above the first
    and green otherwise.  The image is written to
    ``./Pic/<subdir>/<date>.png`` where ``subdir`` is a module global.
    """
    if len(tick_seq) != len(shindex_seq) or len(tick_seq) != len(sentiment_seq):
        print('error(plot) : three sequence length is not same')
        return
    x = range(len(shindex_seq))
    labels = tick_seq
    y1 = shindex_seq
    y2 = sentiment_seq

    def format_fn(tick_val, tick_pos):
        # Only positions that correspond to a sample get a label.
        if int(tick_val) in x:
            return labels[int(tick_val)]
        return ''

    fig = plt.figure(figsize=(12, 8))
    p1 = fig.add_subplot(111)
    p1.xaxis.set_major_formatter(FuncFormatter(format_fn))
    p1.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=12))
    delta = shindex_seq[-1] - shindex_seq[0]
    index_color = "red" if delta > 0 else "green"
    p1.plot(x, y1, label="$SCI$", color=index_color, linewidth=1)
    # The colour is given only through the keyword; specifying it in the
    # format string as well makes matplotlib warn about a redundant
    # definition.
    p1.plot(x, y2, '--', label="$ISI$", color="blue", linewidth=1)
    plt.title("Shanghai Composite Index(SCI) & Investor Sentiment Index(ISI)")
    plt.xlabel("Time(5min)")
    plt.ylabel("Index Value")
    plt.legend()
    global subdir
    filepath = './Pic/' + subdir + '/' + date + '.png'
    plt.savefig(filepath)
    # Release the figure; otherwise every call leaves one more open figure
    # registered with pyplot.
    plt.close(fig)
def limite_central():
    """Histogram of sample means of ``rand`` draws of growing size.

    For ``i`` in 0..N-1 the mean of ``i + 1`` values from ``rand`` is
    stored, and the means are plotted as a normalised histogram over
    ``mu +/- epsilon`` on the current pyplot axes.  Y ticks are rendered
    by ``to_percent``.
    """
    N = 5000
    epsilon = 5e-2
    x1 = zeros(N)
    mu = 0.5
    for i in range(N):
        x1[i] = sum(rand(i+1)) / (i+1)
    # ``density`` replaces the ``normed`` keyword, which current matplotlib
    # releases no longer accept.
    plt.hist(x1, bins=100, range=(mu - epsilon, mu + epsilon), density=True)
    formatter = FuncFormatter(to_percent)
    plt.gca().yaxis.set_major_formatter(formatter)
#limite_central()
def _set_integer_tick_labels(axis, labels):
    """Label integer ticks on ``axis`` from the ``labels`` dict.

    Tick values missing from ``labels`` get an empty label.
    """
    def lookup(value, _pos):
        return labels.get(value, '')

    axis.set_major_formatter(FuncFormatter(lookup))
    axis.set_major_locator(MaxNLocator(integer=True))
def plot_yearly(self, ax=None, uncertainty=True, yearly_start=0):
    """Draw the yearly seasonality component.

    Parameters
    ----------
    ax: Optional matplotlib Axes to draw on. A new 10x6 figure is created
        when it is not given.
    uncertainty: Optional boolean; when true the band between
        ``yearly_lower`` and ``yearly_upper`` is shaded.
    yearly_start: Optional int, number of days by which the plotted
        365-day window is shifted from 2017-01-01. 0 (default) starts on
        Jan 1, 1 on Jan 2, and so on.

    Returns
    -------
    a list of matplotlib artists
    """
    if not ax:
        ax = plt.figure(facecolor='w', figsize=(10, 6)).add_subplot(111)

    # One year of daily dates, shifted by ``yearly_start`` days.
    first_year = pd.date_range(start='2017-01-01', periods=365)
    days = first_year + pd.Timedelta(days=yearly_start)
    df_y = self.seasonality_plot_df(days)
    seas = self.predict_seasonal_components(df_y)

    artists = []
    artists.extend(ax.plot(
        df_y['ds'].dt.to_pydatetime(), seas['yearly'], ls='-', c='#0072B2'))
    if uncertainty:
        band = ax.fill_between(
            df_y['ds'].dt.to_pydatetime(), seas['yearly_lower'],
            seas['yearly_upper'], color='#0072B2', alpha=0.2)
        artists.append(band)
    ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)

    month_locator = MonthLocator(range(1, 13), bymonthday=1, interval=2)
    ax.xaxis.set_major_formatter(FuncFormatter(
        lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
    ax.xaxis.set_major_locator(month_locator)
    ax.set_xlabel('Day of year')
    ax.set_ylabel('yearly')
    return artists
def plotDatePrice(productID, productTitle, data):
    """Plot a price history and save it as ``<productID>.png``.

    ``data`` is an iterable of ``'YYYY-MM-DD|price'`` strings.  Returns
    the name of the file that was written.
    """
    # Data setup
    x, y = [], []
    for datapoint in data:
        fields = datapoint.split('|')
        x.append(dt.datetime.strptime(fields[0], '%Y-%m-%d'))
        y.append(float(fields[1]))
    x = matplotlib.dates.date2num(x)
    x_np, y_np = np.array(x), np.array(y)
    # Plot setup
    ax = plt.figure(figsize=(6, 3)).add_subplot(111)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    ax.plot(x_np, y_np, color='lightblue', lw=2)
    ax.margins(0.05)
    ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: ('$%i' % (x))))
    ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
    plt.yticks(fontsize=8)
    # Limits are passed positionally: the ``ymin``/``ymax`` keywords are no
    # longer accepted by current matplotlib releases.
    plt.ylim(min(y)*0.7, max(y)*1.3)
    plt.title('Recent Price History\n'+productTitle, weight ='light', fontsize=12, y=1.08)
    plt.xticks(rotation=40, fontsize=7)
    plt.tight_layout()
    plt.savefig(productID+'.png')
    return productID+'.png'
# ----- Email Configuration ----------------------------------------------------
def ticklabels_to_percent(ax, axis='y'):
    """Format the major ticks of the chosen axis as ``'{:0.2%}'``.

    ``axis`` is the prefix of the axis attribute on ``ax`` (``'x'`` or
    ``'y'``).  Returns ``ax``.
    """
    def as_percent(s, position):
        return '{:0.2%}'.format(s)

    target_axis = getattr(ax, '{0}axis'.format(axis))
    target_axis.set_major_formatter(mticker.FuncFormatter(as_percent))
    return ax
def ticklabels_to_thousands_sep(ax, axis='y'):
    """Format the major ticks of the chosen axis as comma-grouped integers.

    ``axis`` is the prefix of the axis attribute on ``ax`` (``'x'`` or
    ``'y'``).  Returns ``ax``.
    """
    def with_separators(s, position):
        return '{:,}'.format(int(s))

    target_axis = getattr(ax, '{0}axis'.format(axis))
    target_axis.set_major_formatter(mticker.FuncFormatter(with_separators))
    return ax
def plot(self, addseries=None, log=True, title='Discharge'):
    """
    Quick plot of ``self.Q`` against ``self.time``, with or without rain.

    Parameters
    ----------
    addseries : sequence, optional
        Extra series to overlay for comparison, flattened as
        ``[time, Q, time, Q, ...]``.  Pairs are drawn in the order given
        and labelled ``Series2``, ``Series3``, ...
    log : bool
        Use a log scale with plain decimal tick labels on the y axis.
    title : str
        Plot title.

    Raises
    ------
    ValueError
        If ``addseries`` does not hold complete ``time, Q`` pairs.
    """
    if addseries is None:
        addseries = ()
    if len(addseries) % 2:
        raise ValueError(
            'addseries must hold time/Q pairs: [time, Q, time, Q, ...]')

    fig = plt.figure()
    ax1 = fig.add_subplot(111, facecolor=[.95, .95, .95])
    plt.grid(True, which='both', color='w', ls='-', zorder=0)
    ax1.plot(self.time, self.Q, label='Series1')
    if len(self.rain) != 0:
        ax2 = ax1.twinx()
        ax2.plot(self.time, self.rain, alpha=.5, c='b', lw=1, label='Rain')
        # Inverted limits: rain is drawn downwards from the top edge.
        ax2.set_ylim(1, 0)
        ax2.set_ylabel(r'Rain, in')
    ax1.set_ylabel('Discharge, cfs')
    # NOTE(review): the x data is self.time, yet the label reads
    # 'Stage, ft' -- confirm the intended label.
    ax1.set_xlabel('Stage, ft')
    # log scale for y axis
    if log:
        ax1.set_yscale('log')
        ax1.yaxis.set_major_formatter(FuncFormatter(lambda y,pos: ('{{:.{:1d}f}}'.format(int(np.maximum(-np.log10(y),0)))).format(y)))
    # Overlay the additional series, pair by pair, in the order supplied.
    pairs = zip(addseries[0::2], addseries[1::2])
    for number, (series_time, series_q) in enumerate(pairs, start=2):
        ax1.plot(series_time, series_q, label=f'Series{number}')
    ax1.legend(loc='best')
    plt.title(title)
    plt.show()
def plot(outfn, a, genomeSize, base2chr, _windowSize, dpi=300, ext="svg"):
    """Save contact plot.

    ``a`` is drawn with ``imshow`` (as ``a + 1`` on a log colour scale)
    over ``0..genomeSize`` on both axes.  ``base2chr`` maps positions to
    chromosome names; with fewer than 50 entries the y axis is labelled
    with those names.  The figure is written to ``<outfn>.<ext>``.
    """
    def format_fn(tick_val, tick_pos):
        """Mark axis ticks with chromosome names"""
        if int(tick_val) in base2chr:
            return base2chr[int(tick_val)]
        else:
            sys.stderr.write("[WARNING] %s not in ticks!\n"%tick_val)
            return ''
    # Invert base2chr positions.  ``items()`` is used because
    # ``iteritems()`` exists only on Python 2 dicts.
    base2chr = {genomeSize-b: c for b, c in base2chr.items()}
    # start figure
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title("Contact intensity plot [%sk]"%(_windowSize/1000,))
    # label Y axis with chromosome names
    if len(base2chr)<50:
        ax.yaxis.set_major_formatter(FuncFormatter(format_fn))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Hand over a real list rather than a dict view.
        plt.yticks(list(base2chr))
        ax.set_ylabel("Chromosomes")
    else:
        ax.set_ylabel("Genome position")
    # label axes
    ax.set_xlabel("Genome position")
    plt.imshow(a+1, cmap=cm.hot, norm=LogNorm(), extent=(0, genomeSize, 0, genomeSize))
    plt.colorbar()
    # save
    fig.savefig("%s.%s"%(outfn,ext), dpi=dpi, papertype="a4")
def plot(outfn, a, genomeSize, base2chr, _windowSize, dpi=300, ext="svg"):
    """Save contact plot.

    ``a`` is drawn with ``imshow`` (as ``a + 1`` on a log colour scale)
    over ``0..genomeSize`` on both axes.  ``base2chr`` maps positions to
    chromosome names; with fewer than 50 entries the y axis is labelled
    with those names.  The figure is written to ``<outfn>.<ext>``.
    """
    def format_fn(tick_val, tick_pos):
        """Mark axis ticks with chromosome names"""
        if int(tick_val) in base2chr:
            return base2chr[int(tick_val)]
        else:
            sys.stderr.write("[WARNING] %s not in ticks!\n"%tick_val)
            return ''
    # Invert base2chr positions.  ``items()`` is used because
    # ``iteritems()`` exists only on Python 2 dicts.
    base2chr = {genomeSize-b: c for b, c in base2chr.items()}
    # start figure
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title("Contact intensity plot [%sk]"%(_windowSize/1000,))
    # label Y axis with chromosome names
    if len(base2chr)<50:
        ax.yaxis.set_major_formatter(FuncFormatter(format_fn))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Hand over a real list rather than a dict view.
        plt.yticks(list(base2chr))
        ax.set_ylabel("Chromosomes")
    else:
        ax.set_ylabel("Genome position")
    # label axes
    ax.set_xlabel("Genome position")
    plt.imshow(a+1, cmap=cm.hot, norm=LogNorm(), extent=(0, genomeSize, 0, genomeSize))
    plt.colorbar()
    # save
    fig.savefig("%s.%s"%(outfn,ext), dpi=dpi, papertype="a4")
def pimp_axis(x_or_y_ax):
    """Install ``ticks_formatter`` as the major tick formatter of the axis.

    Per the original note, the purpose is to remove trailing zeros.
    """
    formatter = ticker.FuncFormatter(ticks_formatter)
    x_or_y_ax.set_major_formatter(formatter)
def plot_yield_general():
    """Plot and save cumulative yield (``cumsum_bp``) against duration.

    Data comes from the module-level ``YIELD_DATA``; the image is written
    into ``PLOTS_DIR`` with a name derived from ``SAMPLE_NAME``.
    """
    # Close any previous plots
    plt.close('all')
    # Set subplots.
    fig, ax = plt.subplots(1)
    # Seven evenly spaced ticks (both end points included) across the
    # observed duration range.
    num_points = 7
    durations = YIELD_DATA['duration_float']
    x_ticks = np.linspace(durations.min(), durations.max(), num_points)
    ax.set_xticks(x_ticks)
    # Define axis formatters
    ax.yaxis.set_major_formatter(FuncFormatter(y_yield_to_human_readable))
    ax.xaxis.set_major_formatter(FuncFormatter(x_yield_to_human_readable))
    # Set x and y labels and title
    ax.set_xlabel("Duration (HH:MM)")
    ax.set_ylabel("Yield")
    ax.set_title(f"Yield for {SAMPLE_NAME} over time")
    # Produce plot
    ax.plot(YIELD_DATA['duration_float'], YIELD_DATA['cumsum_bp'],
            linestyle="solid", markevery=[])
    # Limits must be set after the plot is created
    ax.set_xlim(YIELD_DATA['duration_float'].min(), YIELD_DATA['duration_float'].max())
    # ``bottom`` replaces the ``ymin`` keyword, which current matplotlib
    # releases no longer accept.
    ax.set_ylim(bottom=0)
    # Ensure labels are not missed.
    fig.tight_layout()
    savefig(os.path.join(PLOTS_DIR, f"{SAMPLE_NAME.replace(' ', '_')}_yield_plot.png"))
def plot_read_length_hist():
    """Plot and save a length-weighted histogram of read lengths.

    Reads come from ``ALL_READS["seq_length"]``; each read is weighted by
    its own length.  When ``CLIP`` is set, only reads below the 0.9995
    quantile are kept.  The image is written into ``PLOTS_DIR`` with a
    name derived from ``SAMPLE_NAME``.
    """
    # Close any previous plots
    plt.close('all')
    num_bins = 50
    seq_df = ALL_READS["seq_length"]
    if CLIP:
        # Keep only reads shorter than the 0.9995 quantile.
        seq_df = seq_df[seq_df < seq_df.quantile(0.9995)]

    def y_hist_to_human_readable_seq(y, position):
        # Scale the density value by the total of seq_df and render it as
        # a human-readable size.
        if y == 0:
            return 0
        s = humanfriendly.format_size(seq_df.sum() * y, binary=False)
        return reformat_human_friendly(s)

    # Define how many plots we want (1)
    fig, ax = plt.subplots(1)
    # Set the axis formatters
    ax.yaxis.set_major_formatter(FuncFormatter(y_hist_to_human_readable_seq))
    ax.xaxis.set_major_formatter(FuncFormatter(x_hist_to_human_readable))
    # Plot the histogram.  ``density`` replaces the ``normed`` keyword,
    # which current matplotlib releases no longer accept.
    h, w, p = ax.hist(seq_df, num_bins, weights=seq_df,
                      density=True, facecolor='blue', alpha=0.76)
    bin_width = reformat_human_friendly(humanfriendly.format_size(w[1]-w[0], binary=False))
    # Set the titles and axis labels
    ax.set_title(f"Read Distribution Graph for {SAMPLE_NAME}")
    ax.grid(color='black', linestyle=':', linewidth=0.5)
    ax.set_xlabel(f"Read length: Bin Widths={bin_width}")
    ax.set_ylabel("Bases per bin")
    # Ensure labels are not missed.
    fig.tight_layout()
    savefig(os.path.join(PLOTS_DIR, f"{SAMPLE_NAME.replace(' ', '_')}_hist_read_length_by_basepair.png"))
def plot_macd(df):
    """Show the MACD signal and histogram series computed from ``df``.

    ``macd_data(df)`` yields three series; only the signal (red) and the
    histogram (black) are drawn, against the row number.
    """
    _, macdsignal, macdhist = macd_data(df)
    positions = np.arange(len(df))

    fig = plt.figure(figsize=(8, 4))
    ax = fig.add_subplot(111)
    ax.grid(True)
    ax.plot(positions, macdsignal, 'r-', label=u'macdsignal')
    ax.plot(positions, macdhist, 'k-', label=u'macdhist')
    ax.legend(loc='best')
    year_label = str(df[u'date'][:].year)
    ax.set_xlabel(year_label)
    plt.show()
def make_probes_ba_traj_fig(models1, models2=None, palette=None):  # TODO ylim
    """
    Returns fig showing trajectory of probes balanced accuracy.

    For each model group (``models1`` and, if given, ``models2``) the mean
    ``probes_ba`` trajectory is drawn with a grey band of +/- one standard
    error.  ``palette`` is an optional iterator of colours, one per group;
    without it lines are black.
    """
    start = time.time()
    sns.set_style('white')
    # load data
    xys = []
    model_groups = [models1] if models2 is None else [models1, models2]
    for models in model_groups:
        model_probes_ba_trajs = [model.get_traj('probes_ba') for model in models]
        x = models[0].get_data_step_axis()
        # Truncate every trajectory to the length of the first model's
        # step axis so they stack into one matrix.
        traj_mat = np.asarray([traj[:len(x)] for traj in model_probes_ba_trajs])
        y = np.mean(traj_mat, axis=0)
        sem = [stats.sem(model_probes_bas) for model_probes_bas in traj_mat.T]
        xys.append((x, y, sem))
    # fig
    fig, ax = plt.subplots(figsize=(FigsConfigs.MAX_FIG_WIDTH, 3))
    ax.set_ylim([50, 75])
    ax.set_xlabel('Mini Batch', fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
    ax.set_ylabel('Probes Balanced Accuracy', fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # Booleans are required here: the string 'off' is truthy, so current
    # matplotlib releases would turn these ticks on instead of off.
    ax.tick_params(axis='both', which='both', top=False, right=False)
    ax.xaxis.set_major_formatter(FuncFormatter(human_format))
    ax.yaxis.grid(True)
    # plot
    for (x, y, sem) in xys:
        color = next(palette) if palette is not None else 'black'
        ax.plot(x, y, '-', linewidth=FigsConfigs.LINEWIDTH, color=color)
        ax.fill_between(x, np.add(y, sem), np.subtract(y, sem), alpha=FigsConfigs.FILL_ALPHA, color='grey')
    plt.tight_layout()
    print('{} completed in {:.1f} secs'.format(sys._getframe().f_code.co_name, time.time() - start))
    return fig
def make_probes_pp_traj_fig(models1, models2=None, palette=None):
    """
    Returns fig showing trajectory of Probes Perplexity.

    For each model group (``models1`` and, if given, ``models2``) the mean
    ``probes_pp`` trajectory is drawn solid ("weighted") and the mean
    ``probes_pp_uw`` trajectory dashed ("unweighted").  ``palette`` is an
    optional iterator of colours, one per group; without it lines are
    black.
    """
    start = time.time()
    sns.set_style('white')
    # load data
    xys = []
    model_groups = [models1] if models2 is None else [models1, models2]
    for models in model_groups:
        probes_pp_trajs_w = [model.get_traj('probes_pp') for model in models]
        probes_pp_trajs_uw = [model.get_traj('probes_pp_uw') for model in models]
        x = models[0].get_data_step_axis()
        # Truncate every trajectory to the length of the first model's
        # step axis so they stack into one matrix.
        traj_mat1 = np.asarray([traj[:len(x)] for traj in probes_pp_trajs_w])
        traj_mat2 = np.asarray([traj[:len(x)] for traj in probes_pp_trajs_uw])
        y1 = np.mean(traj_mat1, axis=0)
        y2 = np.mean(traj_mat2, axis=0)
        xys.append((x, y1, y2))
    # fig
    fig, ax = plt.subplots(figsize=(FigsConfigs.MAX_FIG_WIDTH, 3))
    ylabel = 'Probes Perplexity'
    ax.set_ylabel(ylabel, fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # Booleans are required here: the string 'off' is truthy, so current
    # matplotlib releases would turn these ticks on instead of off.
    ax.tick_params(axis='both', which='both', top=False, right=False)
    ax.set_xlabel('Mini Batch', fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
    ax.xaxis.set_major_formatter(FuncFormatter(human_format))
    ax.yaxis.grid(True)
    # plot
    for (x, y1, y2) in xys:
        color = next(palette) if palette is not None else 'black'
        # The line style is given only through the keyword; combining it
        # with a positional format string defines it twice.
        ax.plot(x, y1, linewidth=FigsConfigs.LINEWIDTH, color=color, linestyle='-', label='weighted')
        ax.plot(x, y2, linewidth=FigsConfigs.LINEWIDTH, color=color, linestyle='--', label='unweighted')
    plt.legend(loc='best')
    plt.tight_layout()
    print('{} completed in {:.1f} secs'.format(sys._getframe().f_code.co_name, time.time() - start))
    return fig
def make_avg_traj_figs(model):
    """Return one trajectory figure per name in ``AppConfigs.EVAL_NAMES``."""
    def make_avg_traj_fig(traj_name):
        """
        Returns fig showing the trajectory named ``traj_name`` for ``model``.

        The y limits are taken from ``model.eval_name_range_dict``.
        """
        start = time.time()
        sns.set_style('white')
        ylims = model.eval_name_range_dict[traj_name]
        # load data
        x = model.get_data_step_axis()
        y = model.get_traj(traj_name)
        # fig
        fig, ax = plt.subplots(figsize=(FigsConfigs.MAX_FIG_WIDTH, 3), dpi=FigsConfigs.DPI)
        ax.set_ylim(ylims)
        ax.set_xlabel('Mini Batch', fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
        ax.set_ylabel(traj_name, fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        # Booleans are required here: the string 'off' is truthy, so
        # current matplotlib releases would turn these ticks on.
        ax.tick_params(axis='both', which='both', top=False, right=False)
        ax.xaxis.set_major_formatter(FuncFormatter(human_format))
        ax.yaxis.grid(True)
        # plot
        ax.plot(x, y, '-', linewidth=FigsConfigs.LINEWIDTH, color='black')
        plt.tight_layout()
        print('{} completed in {:.1f} secs'.format(sys._getframe().f_code.co_name, time.time() - start))
        return fig

    figs = [make_avg_traj_fig(traj_name) for traj_name in AppConfigs.EVAL_NAMES]
    return figs