def get_heatmaps(neuron_list, spikes, pos, num_bins=100, *, padding=2, verbose=True):
    """Gets the 2D heatmaps for firing of a given set of neurons.

    For every spike of each selected neuron, the position sample at index
    ``find_nearest_idx(pos.time, spike)`` (presumably the sample closest in
    time to the spike) is collected, and those positions are binned into a
    2D histogram. All neurons share the same bin edges, which span the full
    range of ``pos.x`` / ``pos.y`` widened by ``padding`` on each side.

    Parameters
    ----------
    neuron_list : list of ints
        Indices into ``spikes`` selecting which neurons to process.
    spikes : list
        Containing nept.SpikeTrain for each neuron.
    pos : nept.Position
        Must be 2D.
    num_bins : int
        Number of bins along each axis; the greater the number, the finer
        the heatmap. The default is 100.
    padding : float, keyword-only
        Margin subtracted from the minimum and added to the maximum of
        ``pos.x`` and ``pos.y`` when building the bin edges. Default is 2.
    verbose : bool, keyword-only
        If True (the default), print a progress line after each neuron.

    Returns
    -------
    heatmaps : dict
        Maps each neuron index from ``neuron_list`` to a 2D numpy array of
        spike counts with shape (num_bins, num_bins). The histogram is
        transposed so that rows index y bins and columns index x bins.

    Raises
    ------
    ValueError
        If ``pos`` is not two-dimensional.
    """
    if pos.dimensions != 2:
        raise ValueError("pos must be two-dimensional")

    # Edges are computed once so every neuron's heatmap uses the same grid.
    xedges = np.linspace(np.min(pos.x) - padding, np.max(pos.x) + padding,
                         num_bins + 1)
    yedges = np.linspace(np.min(pos.y) - padding, np.max(pos.y) + padding,
                         num_bins + 1)

    heatmaps = dict()
    n_neurons = len(neuron_list)
    for count, neuron in enumerate(neuron_list, start=1):
        field_x = []
        field_y = []
        for spike in spikes[neuron].time:
            spike_idx = find_nearest_idx(pos.time, spike)
            field_x.append(pos.x[spike_idx])
            field_y.append(pos.y[spike_idx])

        heatmap, _, _ = np.histogram2d(field_x, field_y,
                                       bins=[xedges, yedges])
        # histogram2d puts x bins along axis 0; transpose so rows are y bins.
        heatmaps[neuron] = heatmap.T

        if verbose:
            # Progress is reported by position in neuron_list, not by the
            # neuron's index into spikes.
            print(str(count) + ' of ' + str(n_neurons))
    return heatmaps
评论列表
文章目录