model.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:rnnlab 作者: phueb 项目源码 文件源码
def get_trajs_mat(self, cols, traj):
    """Load a set of trajectories from an HDF5 store as a 2-D array.

    The selected DataFrame columns are read from the store that matches
    ``traj`` and returned transposed, so each selected column becomes one
    row of the result.

    Args:
        cols: For ``'avg_probe_pp'`` / ``'avg_probe_ba'``, the column names
            to select. For ``*cat_task*`` / ``*syn_task*`` trajectories, an
            iterable of fold identifiers; each is formatted into a column
            name of the form ``<task>_fold<i>``.
        traj: Which trajectory family to load. Either ``'avg_probe_pp'``,
            ``'avg_probe_ba'``, or a string containing ``'cat_task'`` or
            ``'syn_task'`` (checked in that order).

    Returns:
        The selected values transposed (``df.values.transpose()``), i.e.
        one row per selected column.

    Raises:
        AttributeError: If ``traj`` is not a string or does not match any
            known trajectory family.
    """
    # Reject non-strings up front so callers always get the documented
    # error instead of a TypeError from the substring checks below.
    if not isinstance(traj, str):
        raise AttributeError('rnnlab: Invalid argument passed to "traj".')

    # Resolve (store path, HDF key, columns) for the requested family.
    # Store paths are read lazily, only for the branch that matches.
    if traj == 'avg_probe_pp':
        path, key, columns = self.pp_traj_df_path, 'pp_traj_df', cols
    elif traj == 'avg_probe_ba':
        path, key, columns = self.ba_traj_df_path, 'ba_traj_df', cols
    elif 'cat_task' in traj:
        path, key = self.cat_task_traj_df_path, 'cat_task_traj_df'
        task_name = traj.replace('cat_task_', '')
        columns = [task_name + '_fold{}'.format(i) for i in cols]
    elif 'syn_task' in traj:
        path, key = self.syn_task_traj_df_path, 'syn_task_traj_df'
        task_name = traj.replace('syn_task_', '')
        columns = [task_name + '_fold{}'.format(i) for i in cols]
    else:
        raise AttributeError('rnnlab: Invalid argument passed to "traj".')

    # select() returns an in-memory DataFrame, so the store can be closed
    # before converting to an array.
    with pd.HDFStore(path, mode='r') as store:
        df_traj = store.select(key, columns=columns)
    return df_traj.values.transpose()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号