def _forward_rnn(cell, input_, length, hx):
max_time = input_.size(0)
output = []
for time in range(max_time):
h_next = cell(input_=input_[time], hx=hx)
# mask = (time < length).float().unsqueeze(1).expand_as(h_next)
# h_next = h_next*mask + hx*(1 - mask)
output.append(h_next)
output = torch.stack(output, 0)
return output, h_next
评论列表
文章目录