def _forward_rnn(cell, input_, length, hx):
    """Run ``cell`` over a time-major sequence, freezing state past each length.

    At every step the cell's new state is kept only for rows where
    ``time < length``; the remaining rows carry their previous state forward
    unchanged.

    Args:
        cell: Recurrent cell called once per step as
            ``cell(input_=..., hx=...)`` and returning an ``(h, c)`` pair.
            A ``BNLSTMCell`` is additionally passed the step index as
            ``time``.
        input_: Sequence tensor whose first dimension is time; ``input_[t]``
            is fed to the cell at step ``t``.
        length: Tensor compared element-wise against the step index
            (presumably one sequence length per batch row, shape
            ``(batch,)`` -- confirm against the caller).
        hx: Initial ``(h, c)`` state pair.

    Returns:
        A tuple ``(output, hx)``: the per-step hidden states stacked along a
        new leading time dimension, and the state pair after the last step.
    """
    # The cell's type cannot change mid-sequence, so decide once up front.
    pass_time = isinstance(cell, BNLSTMCell)
    max_time = input_.size(0)
    output = []
    for time in range(max_time):
        if pass_time:
            h_next, c_next = cell(input_=input_[time], hx=hx, time=time)
        else:
            h_next, c_next = cell(input_=input_[time], hx=hx)
        # 1 where the sequence is still running, 0 where it has ended.
        # Cast to the hidden state's own tensor type (rather than a fixed
        # float32) so the blend below does not mix dtypes.
        mask = (time < length).unsqueeze(1).type_as(h_next).expand_as(h_next)
        # Finished rows keep their previous state instead of the new one.
        h_next = h_next * mask + hx[0] * (1 - mask)
        c_next = c_next * mask + hx[1] * (1 - mask)
        output.append(h_next)
        hx = (h_next, c_next)
    output = torch.stack(output, 0)
    return output, hx
评论列表
文章目录