state.py source file

python

Project: seq2seq.pytorch  Author: eladhoffer
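# Excerpt of a class method; torch, Variable (torch.autograd) and State must be in scope.
def __merge_states(self, state_list, type_state='hidden'):
    """Merge a list of states into a single state."""
    if state_list is None:
        return None
    if isinstance(state_list[0], State):
        return State().from_list(state_list)
    if isinstance(state_list[0], tuple):
        # Regroup tuple states by position (e.g. all h, then all c) and merge each group.
        return tuple(self.__merge_states(s, type_state) for s in zip(*state_list))
    else:
        if isinstance(state_list[0], Variable) or torch.is_tensor(state_list[0]):
            if type_state == 'hidden':
                # Hidden tensors with fewer than 3 dims carry the batch on dim 0;
                # 3-D ones are treated as (layers, batch, ...), so the batch is dim 1.
                batch_dim = 0 if state_list[0].dim() < 3 else 1
            else:
                # Other state types follow the module's batch_first setting.
                batch_dim = 0 if self.batch_first else 1
            return torch.cat(state_list, batch_dim)
        else:
            # Non-tensor entries cannot be concatenated, so they must all be equal.
            assert state_list[1:] == state_list[:-1]  # all items are equal
            return state_list[0]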
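
To see the recursion and the batch-dimension rule in isolation, here is a minimal dependency-free sketch. It is not the project's code: nested Python lists stand in for tensors, the State branch is left out, and ndim, concat, and merge_states are hypothetical helper names introduced only for this illustration.

```python
# Dependency-free sketch: nested lists stand in for tensors.
# ndim, concat and merge_states are hypothetical helpers, not part of seq2seq.pytorch.

def ndim(x):
    """Number of nesting levels of a non-empty, rectangular nested list."""
    n = 0
    while isinstance(x, list):
        n += 1
        x = x[0]
    return n


def concat(items, dim):
    """Concatenate nested lists along `dim`, mimicking torch.cat on tensors."""
    if dim == 0:
        return [row for item in items for row in item]
    # Descend one level: pair up the corresponding sub-lists of every item.
    return [concat(list(parts), dim - 1) for parts in zip(*items)]


def merge_states(state_list, type_state="hidden", batch_first=False):
    if state_list is None:
        return None
    first = state_list[0]
    if isinstance(first, tuple):
        # e.g. an LSTM-style (h, c) pair: merge all h's together, then all c's.
        return tuple(
            merge_states(list(s), type_state, batch_first) for s in zip(*state_list)
        )
    if isinstance(first, list):  # stand-in for the tensor branch
        if type_state == "hidden":
            batch_dim = 0 if ndim(first) < 3 else 1
        else:
            batch_dim = 0 if batch_first else 1
        return concat(state_list, batch_dim)
    # Anything else is treated as metadata that must agree across all entries.
    assert all(s == first for s in state_list)
    return first


# Two states, each an (h, c) pair of shape (layers=1, batch=1, hidden=2).
state_a = ([[[1, 2]]], [[[3, 4]]])
state_b = ([[[5, 6]]], [[[7, 8]]])
h, c = merge_states([state_a, state_b])
print(h)  # [[[1, 2], [5, 6]]]
print(c)  # [[[3, 4], [7, 8]]]
```

The merged h and c keep a single layer and gain a batch of two, which is what concatenating 3-D hidden states along dimension 1 should produce.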