def forward(self, inputs, hidden):
    def select_layer(h_state, i):  # works for LSTM (tuple state) as well as GRU / RNN
        # An LSTM hidden state is an (h, c) tuple; GRU / RNN is a single tensor.
        if isinstance(h_state, tuple):
            return tuple([select_layer(s, i) for s in h_state])
        else:
            return h_state[i]

    next_hidden = []
    for i, layer in enumerate(self.layers):
        # Pick out layer i's slice of the stacked hidden state.
        hidden_i = select_layer(hidden, i)
        next_hidden_i = layer(inputs, hidden_i)
        # For an LSTM cell the new state is (h, c); its output is h.
        output = next_hidden_i[0] if isinstance(next_hidden_i, tuple) \
            else next_hidden_i
        # Apply dropout between layers, but not after the last one.
        if i + 1 != self.num_layers:
            output = self.dropout(output)
        # Residual connection: add the layer's input to its output
        # (skipped for the first layer, whose input size may differ).
        if i > 0 and self.residual:
            inputs = output + inputs
        else:
            inputs = output
        next_hidden.append(next_hidden_i)
    # Re-stack the per-layer states into the same layout as the incoming hidden.
    if isinstance(hidden, tuple):
        next_hidden = tuple([torch.stack(h) for h in zip(*next_hidden)])
    else:
        next_hidden = torch.stack(next_hidden)
    return inputs, next_hidden
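
For context, here is a minimal sketch of the module this forward() is assumed to live in. The class name StackedCell, the constructor arguments, and the choice of nn.LSTMCell are illustrative assumptions, not taken from the original code; the sketch only shows how self.layers, self.num_layers, self.dropout and self.residual would be set up and what shapes the method expects.

import torch
import torch.nn as nn

class StackedCell(nn.Module):
    # Hypothetical wrapper: stacks single-step RNN cells so the forward()
    # above can iterate over them layer by layer.
    def __init__(self, input_size, hidden_size, num_layers=2,
                 dropout=0.0, residual=False, cell=nn.LSTMCell):
        super().__init__()
        self.num_layers = num_layers
        self.residual = residual
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [cell(input_size if i == 0 else hidden_size, hidden_size)
             for i in range(num_layers)])

    # forward(self, inputs, hidden) as defined above goes here.

# Expected shapes for one step with LSTM cells:
#   inputs: (batch, input_size)
#   hidden: (h, c), each of shape (num_layers, batch, hidden_size)
# stacked = StackedCell(input_size=8, hidden_size=8, num_layers=2, residual=True)
# x = torch.randn(4, 8)
# state = (torch.zeros(2, 4, 8), torch.zeros(2, 4, 8))
# out, new_state = stacked(x, state)   # out: (4, 8)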