def _decorate_history_panel(axis, image, quatList, quatSwitchPos, n_rows, n_iter):
    """Add quaternion-block markers, a colorbar and axis limits to one panel.

    axis          -- the matplotlib axes that already shows `image`
    image         -- the object returned by `axis.imshow`, used for the colorbar
    quatList      -- per-iteration quaternion values, used for the "quat%d" labels
    quatSwitchPos -- iteration indices at which `quatList` changes value
    n_rows        -- upper y-limit of the panel (length of one history row)
    n_iter        -- number of iterations (x-limit is set to n_iter-1)
    """
    (e_min, e_max) = (1, n_rows)
    e_int = 0.1*(e_max-e_min)
    label_style = dict(size=8, rotation=0, va='center', color='w')

    # The first block starts at iteration 0; its label points left at the line.
    axis.plot([0, 0], [e_min, e_max], 'k-')
    axis.text(1, e_max-e_int, "quat%d"%quatList[0], ha='left',
              bbox=dict(boxstyle="larrow,pad=0.1", facecolor='0.1'), **label_style)

    # One vertical line per switch. Each successive label is placed one
    # step (e_int) lower than the previous one so that they are staggered.
    for n, qs in enumerate(quatSwitchPos):
        axis.plot([qs, qs], [e_min, e_max], 'k-')
        axis.text(qs-1, e_max+(-n-1)*e_int, "quat%d"%quatList[qs], ha='right',
                  bbox=dict(boxstyle="rarrow,pad=0.1", facecolor='0.1'), **label_style)

    # Colorbar in its own narrow axes to the right of the panel.
    divider = make_axes_locatable(axis)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    P.colorbar(image, cax=cax)

    axis.set_ylim(e_min, e_max)
    axis.set_xlim(0, n_iter-1)


def make_mutual_info_plot(fn):
    """Plot the orientation and mutual-information history and save it as a PDF.

    Reads "/history/angle", "/history/mutual_info" and "/history/quaternion"
    from `fn` via `read.extract_arr_from_h5`, splits the iterations into
    blocks at every position where the quaternion value changes, and draws
    two stacked images (iterations along x):

    * top: each angle row divided by its own maximum, with the columns of
      every block reordered by the argsort of that block's last row;
    * bottom: mutual information divided by log(quatSize[quatList[0]]),
      with every row sorted in ascending order.

    fn -- passed unchanged to `read.extract_arr_from_h5` as its first
          argument (presumably the path of an HDF5 output file -- confirm
          against that helper).

    Returns None. Side effects: updates the global rcParams font size,
    turns interactive plotting off, writes "mutual_info_plot.pdf" into the
    current working directory and prints a confirmation message.
    """
    M.rcParams.update({'font.size': 11})
    angleList = N.array([f/f.max() for f in read.extract_arr_from_h5(fn, "/history/angle", n=-1)])
    mutualInfoList = read.extract_arr_from_h5(fn, "/history/mutual_info", n=-1)
    quatList = read.extract_arr_from_h5(fn, "/history/quaternion", n=-1)
    # Index of the first iteration of every new quaternion block.
    quatSwitchPos = N.where(quatList[:-1]-quatList[1:] != 0)[0] + 1

    # The last block must run to the end of the history. A stop of None
    # includes the final iteration; the previous stop of -1 silently left
    # the last row unsorted and un-normalised.
    blkPositions = [0] + list(quatSwitchPos) + [None]
    # NOTE: every block is normalised with the quaternion of the first
    # iteration. A per-block normalisation was left disabled in the
    # original code, so that choice is preserved here.
    norm = N.log(quatSize[quatList[0]])
    for start, end in zip(blkPositions[:-1], blkPositions[1:]):
        curr_blk = angleList[start:end]
        if len(curr_blk) == 0:
            continue
        curr_blk2 = mutualInfoList[start:end] / norm
        # Order the columns of the whole block by the block's final row.
        angsort = N.argsort(curr_blk[-1])
        angleList[start:end] = curr_blk[:, angsort]
        # Mutual information is sorted independently for every iteration.
        for n, row in enumerate(curr_blk2):
            mutualInfoList[start+n] = row[N.argsort(row)]

    P.ioff()
    fig, ax = P.subplots(2, 1, sharex=True, figsize=(7, 10))
    fig.subplots_adjust(hspace=0.1)

    im0 = ax[0].imshow(angleList.transpose(), aspect='auto', interpolation=None, cmap=P.cm.OrRd)
    ax[0].set_xlabel("iteration")
    ax[0].set_ylabel("each pattern's most likely orientation\n(sorted by final orientation in each block)")
    _decorate_history_panel(ax[0], im0, quatList, quatSwitchPos,
                            len(angleList[0]), len(angleList))

    im1 = ax[1].imshow(mutualInfoList.transpose(), vmax=.2, aspect='auto', cmap=P.cm.YlGnBu)
    ax[1].set_xlabel("iteration")
    ax[1].set_ylabel("average mutual-information per dataset\n(sorted by average information)")
    _decorate_history_panel(ax[1], im1, quatList, quatSwitchPos,
                            len(mutualInfoList[0]), len(mutualInfoList))

    img_name = "mutual_info_plot.pdf"
    P.savefig(img_name, bbox_inches='tight')
    print("Image has been saved as %s" % img_name)
    P.close(fig)
评论列表
文章目录