def plot_spikes(in_file, in_fft, spikes_list, cols=3,
                labelfmt='t={0:.3f}s (z={1:d})',
                out_file=None):
    """Plot a grid of flagged (volume, slice) pairs and their FFT counterparts.

    For every ``(t, z)`` pair in ``spikes_list`` one grid cell is created with
    two stacked panels: the top panel shows slice ``z`` (third axis) of volume
    ``t`` of ``in_file``; the bottom panel shows the same slice of ``in_fft``
    with a fixed colour range of [-5, 5].  When they exist, the neighbouring
    volumes ``t - 1`` and ``t + 1`` are passed to ``plot_slice_tern`` as
    ``prev`` / ``post``.

    Parameters
    ----------
    in_file : str
        Path to a 4D image, indexed as ``[x, y, z, t]``.  It is reoriented to
        the closest canonical orientation before slicing.
    in_fft : str
        Path to a 4D image indexed like ``in_file`` (presumably the per-slice
        FFT of ``in_file`` -- confirm against the caller).
    spikes_list : sequence of (int, int)
        ``(t, z)`` pairs: volume index and slice index to plot.
    cols : int, optional
        Number of grid columns; one extra column is added when there are more
        than ``cols * 7`` spikes.
    labelfmt : str, optional
        ``str.format`` template for the panel label; receives the time in
        seconds (``t * TR``) and the slice index ``z``.
    out_file : str, optional
        Output SVG path.  Defaults to ``<basename of in_file>.svg`` (with
        ``.nii.gz``-style double extensions stripped) in the working directory.

    Returns
    -------
    str
        Path of the SVG file written.
    """
    import numpy as np
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    nii = nb.as_closest_canonical(nb.load(in_file))
    # NOTE(review): in_fft is NOT reoriented to canonical, unlike in_file;
    # this assumes it was already produced in the canonical orientation.
    # ``get_data()`` was deprecated in nibabel 3 and removed in 5;
    # ``np.asanyarray(img.dataobj)`` is its documented drop-in replacement
    # (it keeps the stored dtype, unlike ``get_fdata()``).
    fft = np.asanyarray(nb.load(in_fft).dataobj)
    data = np.asanyarray(nii.dataobj)
    zooms = nii.header.get_zooms()[:2]   # in-plane voxel spacing
    tstep = nii.header.get_zooms()[-1]   # last zoom: time step between volumes
    ntpoints = data.shape[-1]

    nspikes = len(spikes_list)
    # Add a column for long spike lists (presumably to limit figure height).
    if nspikes > cols * 7:
        cols += 1
    rows = 1
    if nspikes > cols:
        # Float division keeps this correct even without true division.
        rows = int(math.ceil(nspikes / float(cols)))

    fig = plt.figure(figsize=(7 * cols, 5 * rows))
    for i, (t, z) in enumerate(spikes_list):
        # Neighbouring volumes, only when they exist within the series.
        prev = None
        pvft = None
        if t > 0:
            prev = data[..., z, t - 1]
            pvft = fft[..., z, t - 1]
        post = None
        psft = None
        if t < (ntpoints - 1):
            post = data[..., z, t + 1]
            psft = fft[..., z, t + 1]

        # Bottom panel (ax1) holds the FFT slice; a same-sized panel (ax2)
        # is appended on top of it for the image slice.
        ax1 = fig.add_subplot(rows, cols, i + 1)
        divider = make_axes_locatable(ax1)
        ax2 = divider.new_vertical(size="100%", pad=0.1)
        fig.add_axes(ax2)

        plot_slice_tern(data[..., z, t], prev=prev, post=post, spacing=zooms,
                        ax=ax2,
                        label=labelfmt.format(t * tstep, z))
        plot_slice_tern(fft[..., z, t], prev=pvft, post=psft, vmin=-5, vmax=5,
                        cmap=get_parula(), ax=ax1)

    fig.tight_layout()
    if out_file is None:
        fname, ext = op.splitext(op.basename(in_file))
        if ext == '.gz':
            # Strip the inner extension too (e.g. ``.nii`` of ``.nii.gz``).
            fname, _ = op.splitext(fname)
        out_file = op.abspath('%s.svg' % fname)
    fig.savefig(out_file, format='svg', dpi=300, bbox_inches='tight')
    # pyplot keeps every figure alive until it is closed explicitly; release
    # it so repeated calls do not accumulate figures and memory.
    plt.close(fig)
    return out_file
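# "Comment list" (stray web-page navigation label, not code)
# "Table of contents" (stray web-page navigation label, not code)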