def __merge_states(self, state_list, type_state='hidden'):
    """Merge a list of states into a single combined state.

    Dispatches on the type of the first element of ``state_list``:

    * ``State`` instances are combined via ``State().from_list``.
    * tuples are merged position-wise, recursing on each field.
    * tensors (or ``Variable`` objects) are concatenated with ``torch.cat``.
    * any other value (including ``None``) must be equal across the whole
      list, and that single value is returned.

    Args:
        state_list: sequence of states to merge, or ``None``. When not
            ``None`` it must be non-empty (its first element is inspected).
        type_state: ``'hidden'`` selects the concatenation dim from tensor
            rank (dim 0 if the tensor has fewer than 3 dims, else dim 1);
            any other value selects dim 0 if ``self.batch_first`` else 1.

    Returns:
        The merged state, or ``None`` if ``state_list`` is ``None``.

    Raises:
        ValueError: if the items fall through to the "other value" case
            and are not all equal.
    """
    if state_list is None:
        return None
    first = state_list[0]
    if isinstance(first, State):
        return State().from_list(state_list)
    if isinstance(first, tuple):
        # Transpose the list of tuples into per-field groups and merge each.
        return tuple(self.__merge_states(field, type_state)
                     for field in zip(*state_list))
    if isinstance(first, Variable) or torch.is_tensor(first):
        if type_state == 'hidden':
            # NOTE(review): presumably 3-D hidden tensors follow the torch RNN
            # layout (num_layers * num_directions, batch, hidden), putting
            # batch on dim 1 -- confirm against the callers.
            batch_dim = 0 if first.dim() < 3 else 1
        else:
            batch_dim = 0 if self.batch_first else 1
        return torch.cat(state_list, batch_dim)
    # Non-tensor items cannot be concatenated, so they must all agree.
    # Use an explicit check instead of `assert`, which `python -O` strips.
    if state_list[1:] != state_list[:-1]:
        raise ValueError('cannot merge states: non-tensor items are not all equal')
    return first
评论列表
文章目录