def _forward_rnn(cell, input_, grads_, length, hx):
max_time = input_.size(0)
output = []
for time in range(max_time):
hx = cell(input_=input_[time],grads_=grads_[time], hx=hx)
#mask = (time < length).float().unsqueeze(1).expand_as(h_next[0])
#fS_next = h_next[0] * mask + hx[0] * (1 - mask)
#iS_next = h_next[1] * mask + hx[1] * (1 - mask)
#cS_next = h_next[2] * mask + hx[2] * (1 - mask)
#deltaS_next = h_next[3] * mask + hx[3] * (1 - mask)
#hx_next = (fS_next, iS_next, cS_next, deltaS_next)
#output.append(h_next)
#hx = hx_next
#output = torch.stack(output, 0)
#return output,hx
#return hx[2],hx
return hx
评论列表
文章目录