def get_trajs_mat(self, cols, traj):
    """Load trajectory data from an HDF5 store and return it as a matrix.

    Parameters
    ----------
    cols : iterable
        For ``'avg_probe_pp'`` / ``'avg_probe_ba'``: the column names to
        select from the store.  For ``'cat_task*'`` / ``'syn_task*'``: the
        fold indices ``i``; each one becomes the column name
        ``'<name>_fold<i>'``.
    traj : str
        Which trajectory to load: ``'avg_probe_pp'``, ``'avg_probe_ba'``,
        or any string containing ``'cat_task'`` or ``'syn_task'`` (typically
        ``'cat_task_<name>'`` / ``'syn_task_<name>'``).  ``<name>`` is what
        remains after removing every ``'cat_task_'`` / ``'syn_task_'``
        substring.

    Returns
    -------
    numpy.ndarray
        The selected DataFrame's values, transposed, so that there is one row
        per selected column.

    Raises
    ------
    AttributeError
        If ``traj`` is not a string or matches none of the options above.
        The type is AttributeError rather than the more conventional
        ValueError so that existing callers that catch it keep working.
    """
    # Resolve (store path, HDF key, columns to select) before opening any
    # file, so an invalid ``traj`` is rejected without touching ``self`` or
    # doing I/O.  Non-string ``traj`` skips the substring tests (which would
    # otherwise raise an unrelated TypeError) and lands in the error branch.
    is_str = isinstance(traj, str)
    if traj == 'avg_probe_pp':
        store_path, key, columns = self.pp_traj_df_path, 'pp_traj_df', cols
    elif traj == 'avg_probe_ba':
        store_path, key, columns = self.ba_traj_df_path, 'ba_traj_df', cols
    elif is_str and 'cat_task' in traj:
        name = traj.replace('cat_task_', '')
        store_path, key = self.cat_task_traj_df_path, 'cat_task_traj_df'
        # Task stores keep one column per fold, named '<name>_fold<i>'.
        columns = [name + '_fold{}'.format(i) for i in cols]
    elif is_str and 'syn_task' in traj:
        name = traj.replace('syn_task_', '')
        store_path, key = self.syn_task_traj_df_path, 'syn_task_traj_df'
        columns = [name + '_fold{}'.format(i) for i in cols]
    else:
        raise AttributeError(
            'rnnlab: Invalid argument passed to "traj": {!r}.'.format(traj))

    # Open read-only; the store is closed as soon as the selection is loaded.
    with pd.HDFStore(store_path, mode='r') as store:
        df_traj = store.select(key, columns=columns)
    return df_traj.values.transpose()
评论列表
文章目录