def forward(self, input, c0=None, return_hidden=True):
    """Run the stacked recurrent layers over a full sequence.

    Args:
        input: 3-D tensor of shape ``(seq_len, batch, n_in)``.
        c0: optional initial internal states as a 3-D tensor of shape
            ``(depth, batch, n_out * num_directions)``; when ``None``,
            zero states are used for every layer.
        return_hidden: when True, also return the stacked final internal
            states of all layers.

    Returns:
        The last layer's output ``(seq_len, batch, n_out * num_directions)``,
        and, if ``return_hidden``, a tensor ``(depth, batch,
        n_out * num_directions)`` of the final state of each layer.
    """
    assert input.dim() == 3  # (len, batch, n_in)
    dir_ = 2 if self.bidirectional else 1
    if c0 is None:
        # new_zeros keeps the device/dtype of `input`; replaces the
        # deprecated Variable(input.data.new(...).zero_()) idiom
        # (Variable has been a no-op wrapper since PyTorch 0.4).
        zeros = input.new_zeros(input.size(1), self.n_out * dir_)
        # All layers share the same zero state (read-only), matching the
        # original list comprehension which also reused one tensor.
        c0 = [zeros] * self.depth
    else:
        assert c0.dim() == 3  # (depth, batch, n_out*dir_)
        c0 = [x.squeeze(0) for x in c0.chunk(self.depth, 0)]

    prevx = input
    lstc = []
    # Each layer consumes the previous layer's output; assumes
    # len(self.rnn_lst) == self.depth (as the original indexing did).
    for rnn, c0_i in zip(self.rnn_lst, c0):
        h, c = rnn(prevx, c0_i)
        prevx = h
        lstc.append(c)

    if return_hidden:
        return prevx, torch.stack(lstc)
    return prevx
# NOTE(review): trailing blog-page scrape residue, commented out so the file
# stays valid Python — "评论列表" = "comment list", "文章目录" = "article table of contents".