def summarizePerformance(self, test_data_set):
    """Plot price and investment over a window of the test run and save it.

    This function is called at every PERIOD_BTW_SUMMARY_PERFS.

    Parameters
    -----------
    test_data_set
        Object exposing ``observations()``. Element 0 of the result is
        plotted as "Price" and element 1 as "Investment"; only indices
        100..199 of each are plotted (presumably both are 1-D sequences —
        confirm against the caller).

    The figure is written to ``plot.png`` in the current working directory.
    """
    print("Summary Perf")
    observations = test_data_set.observations()
    prices = observations[0][100:200]
    invest = observations[1][100:200]
    steps = np.arange(len(prices))
    # Each price is repeated 10 times on a 10x finer time grid, so the price
    # curve is drawn as a step-like line aligned with ``steps``.
    steps_long = np.arange(len(prices) * 10) / 10.

    # Draw on a fresh figure. host_subplot() otherwise adds axes to the
    # global current figure, so repeated (periodic) calls would keep stacking
    # axes onto the same figure and never release them.
    fig = plt.figure()
    try:
        host = host_subplot(111, axes_class=AA.Axes)
        fig.subplots_adjust(right=0.9, left=0.1)
        par1 = host.twinx()
        host.set_xlabel("Time")
        host.set_ylabel("Price")
        par1.set_ylabel("Investment")
        p1, = host.plot(steps_long, np.repeat(prices, 10), lw=3, c='b',
                        alpha=0.8, ls='-', label='Price')
        p2, = par1.plot(steps, invest, marker='o', lw=3, c='g', alpha=0.5,
                        ls='-', label='Investment')
        par1.set_ylim(-0.09, 1.09)
        # Colour each axis label like the line it describes.
        host.axis["left"].label.set_color(p1.get_color())
        par1.axis["right"].label.set_color(p2.get_color())
        fig.savefig("plot.png")
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)
    print("A plot of the policy obtained has been saved under the name plot.png")
# Python examples of host_subplot() usage (source code)
def plot_machine(self):
    """Load ``self.filename`` via ``file_handler`` and plot it on log y axes.

    When the file lists more than two substances, the data sets are drawn on
    one figure with up to three y axes: the host axis plus two twin axes, the
    second offset 60 points to the right. Each data set gets its own axis.
    Otherwise only the first data set is plotted on a log y axis, with
    non-positive values clipped. The figure is displayed with ``plt.show()``.
    """
    class_instance = file_handler(self.filename)
    class_instance.file_iteration()
    data_sets = class_instance.data_conversion()
    names = class_instance.substances
    if len(names) > 2:
        host = host_subplot(111, axes_class=AA.Axes)
        plt.subplots_adjust(right=0.75)
        par1 = host.twinx()
        par2 = host.twinx()
        plot_axes = [host, par1, par2]
        for ax in plot_axes:
            ax.set_yscale("log")
        # Move the second twin's right-hand axis outward so it does not
        # overlap the first twin's right-hand axis.
        offset = 60
        new_fixed_axis = par2.get_grid_helper().new_fixed_axis
        par2.axis["right"] = new_fixed_axis(loc="right",
                                            axes=par2,
                                            offset=(offset, 0))
        par2.axis["right"].toggle(all=True)
        host.set_xlabel(data_sets[0]["x_unit"])
        # zip() stops at the shortest input, so at most three data sets are
        # drawn. NOTE(review): any further data sets are silently ignored.
        for data_set, _name, ax in zip(data_sets, names, plot_axes):
            x_val = data_set["data"][0]
            y_val = data_set["data"][1]
            ax.set_ylabel(data_set["y_unit"])
            ax.plot(x_val, y_val, label=data_set["sample element"])
        plt.legend()
        plt.show()
    else:
        # NOTE(review): this branch also runs with exactly two substances,
        # in which case only the first data set is plotted. Confirm that is
        # intended.
        data_set = data_sets[0]
        # Make a C-contiguous copy of the x values. Presumably the x column
        # is a strided view of a larger array; TODO confirm why the copy is
        # required.
        x_val = data_set["data"][0].copy(order="C")
        y_val = data_set["data"][1]
        x_unit = data_set["x_unit"]
        y_unit = data_set["y_unit"]
        label = data_set["sample info"][2]
        ax = plt.gca()
        # Equivalent of plt.semilogy(..., nonposy="clip") on every matplotlib
        # version: ``nonposy`` was renamed ``nonpositive`` in 3.3 and removed
        # in 3.5, where passing it raises TypeError.
        try:
            ax.set_yscale("log", nonposy="clip")
        except TypeError:
            ax.set_yscale("log", nonpositive="clip")
        ax.plot(x_val, y_val, label=label)
        plt.xlabel(x_unit)
        plt.ylabel(y_unit)
        plt.legend()
        plt.show()
def plot_machine(self):
    """Load ``self.filename`` via ``file_handler`` and plot it on log y axes.

    When the file lists more than two substances, the data sets are drawn on
    one figure with up to three y axes: the host axis plus two twin axes, the
    second offset 60 points to the right. Each data set gets its own axis.
    Otherwise only the first data set is plotted on a log y axis, with
    non-positive values clipped. The figure is displayed with ``plt.show()``.
    """
    def _xy(data):
        """Return the ``(x, y)`` columns of a data set's ``"data"`` entry.

        Named fields ``"x"``/``"y"`` are tried first (dict, structured array,
        DataFrame). If those are unavailable, positions 0/1 are used (2-D
        array or sequence). Both plotting branches use this helper, so they
        agree on the layout.
        """
        try:
            return data["x"], data["y"]
        except (KeyError, IndexError, ValueError, TypeError):
            return data[0], data[1]

    class_instance = file_handler(self.filename)
    class_instance.file_iteration()
    data_sets = class_instance.data_conversion()
    names = class_instance.substances
    if len(names) > 2:
        host = host_subplot(111, axes_class=AA.Axes)
        plt.subplots_adjust(right=0.75)
        par1 = host.twinx()
        par2 = host.twinx()
        plot_axes = [host, par1, par2]
        for ax in plot_axes:
            ax.set_yscale("log")
        # Move the second twin's right-hand axis outward so it does not
        # overlap the first twin's right-hand axis.
        offset = 60
        new_fixed_axis = par2.get_grid_helper().new_fixed_axis
        par2.axis["right"] = new_fixed_axis(loc="right",
                                            axes=par2,
                                            offset=(offset, 0))
        par2.axis["right"].toggle(all=True)
        host.set_xlabel(data_sets[0]["x_unit"])
        # zip() stops at the shortest input, so at most three data sets are
        # drawn. NOTE(review): any further data sets are silently ignored.
        for data_set, _name, ax in zip(data_sets, names, plot_axes):
            x_val, y_val = _xy(data_set["data"])
            ax.set_ylabel(data_set["y_unit"])
            ax.plot(x_val, y_val, label=data_set["sample element"])
        plt.legend()
        plt.show()
    else:
        # NOTE(review): this branch also runs with exactly two substances,
        # in which case only the first data set is plotted. Confirm that is
        # intended.
        data_set = data_sets[0]
        x_val, y_val = _xy(data_set["data"])
        # Make a C-contiguous copy of the x values. A named-field column of
        # a structured array, for example, is a strided view.
        x_val = x_val.copy(order="C")
        x_unit = data_set["x_unit"]
        y_unit = data_set["y_unit"]
        label = data_set["sample info"][2]
        ax = plt.gca()
        # Equivalent of plt.semilogy(..., nonposy="clip") on every matplotlib
        # version: ``nonposy`` was renamed ``nonpositive`` in 3.3 and removed
        # in 3.5, where passing it raises TypeError.
        try:
            ax.set_yscale("log", nonposy="clip")
        except TypeError:
            ax.set_yscale("log", nonpositive="clip")
        ax.plot(x_val, y_val, label=label)
        plt.xlabel(x_unit)
        plt.ylabel(y_unit)
        plt.legend()
        plt.show()