def plot_posterior_linear(params_fname, fig_fname, control=False, M=20,
                          data_fname='./sandbox/hh_data.txt'):
    """Plot the posterior over observations of a trained linear SGPSSM.

    Loads and column-normalises the dataset, rebuilds an
    ``aep.SGPSSM_Linear`` model, restores its parameters from
    ``params_fname`` and plots, for each of the four observed channels
    (labelled V, m, n, h), the posterior mean with a +/- 2 standard
    deviation band against the data. When ``control`` is True an extra
    bottom panel shows the control input. The figure is saved to
    ``fig_fname``; afterwards ``plot_model_with_control`` or
    ``plot_model_no_control`` is called on the model.

    Args:
        params_fname: file passed to ``model_aep.load_model``.
        fig_fname: path the posterior figure is written to via
            ``plt.savefig``.
        control: if True, the last data column is passed to the model as
            ``x_control`` and plotted in an extra panel labelled 'I'.
        M: forwarded to ``aep.SGPSSM_Linear`` as its third positional
            argument (presumably the number of inducing points - confirm
            against the model class).
        data_fname: text file read with ``np.loadtxt``; defaults to the
            path that was previously hard-coded.
    """
    # Load the dataset and scale every column to unit standard deviation.
    # NOTE(review): a constant column would make this divide by zero.
    data = np.loadtxt(data_fname)
    data = data / np.std(data, axis=0)
    # The first four columns are the observed channels (plotted as
    # V, m, n, h); the last column is the optional control input.
    y = data[:, :4]
    xc = data[:, [-1]]  # list index keeps the 2-D shape (T, 1)

    # Model hyper-parameters.
    Dlatent = 2
    T = y.shape[0]
    x_control = xc if control else None
    no_panes = 5 if control else 4  # one extra panel for the control input

    model_aep = aep.SGPSSM_Linear(y, Dlatent, M,
                                  lik='Gaussian', prior_mean=0, prior_var=1000,
                                  x_control=x_control)
    model_aep.load_model(params_fname)
    # The third return value is not plotted (presumably the variance
    # including observation noise - confirm against get_posterior_y).
    my, vy, _ = model_aep.get_posterior_y()
    # Per-time-step marginal variances: the diagonal over axes 1 and 2.
    vy_diag = np.diagonal(vy, axis1=1, axis2=2)

    colours = ['k', 'r', 'b', 'g']
    labels = ['V', 'm', 'n', 'h']
    plt.figure()
    t = np.arange(T)
    for i, (colour, label) in enumerate(zip(colours, labels)):
        mean_i = my[:, i]
        sd_i = np.sqrt(vy_diag[:, i])
        plt.subplot(no_panes, 1, i + 1)
        # Shaded band: posterior mean +/- 2 standard deviations.
        plt.fill_between(t, mean_i + 2 * sd_i, mean_i - 2 * sd_i,
                         color=colour, alpha=0.4)
        plt.plot(t, mean_i, '-', color=colour)    # posterior mean
        plt.plot(t, y[:, i], '--', color=colour)  # observed data
        plt.ylabel(label)
        plt.xticks([])
        plt.yticks([])
    if control:
        plt.subplot(no_panes, 1, no_panes)
        plt.plot(t, x_control, '-', color='m')
        plt.ylabel('I')
        plt.yticks([])
    plt.xlabel('t')
    plt.savefig(fig_fname)

    if control:
        plot_model_with_control(model_aep, '', '_linear_with_control')
    else:
        plot_model_no_control(model_aep, '', '_linear_no_control')
评论列表
文章目录