def forward(self, input, hidden, ctx, ctx_mask=None):
    """Propagate input through the stacked attention layers."""
    h_0, c_0 = hidden  # each of shape (num_layers, batch, dim)
    h_1, c_1 = [], []
    if ctx_mask is not None:
        # The attention layer expects a ByteTensor mask; convert once
        # before the loop instead of round-tripping through numpy and
        # the CPU on every iteration.
        ctx_mask = ctx_mask.byte().cuda()
    for i, layer in enumerate(self.layers):
        # Each layer consumes its own slice of the stacked hidden state.
        output, (h_1_i, c_1_i) = layer(input, (h_0[i], c_0[i]), ctx, ctx_mask)
        input = output
        # Apply dropout between layers, but not after the last one;
        # the original test `i != len(self.layers)` was always true.
        if i + 1 != len(self.layers):
            input = self.dropout(input)
        h_1 += [h_1_i]
        c_1 += [c_1_i]
    h_1 = torch.stack(h_1)  # (num_layers, batch, dim)
    c_1 = torch.stack(c_1)
    return input, (h_1, c_1)
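
To make the expected shapes concrete, here is a minimal driving sketch. The class name StackedAttentionLSTM, its constructor arguments, and all tensor shapes are assumptions for illustration based on how this forward method indexes and stacks its hidden states; they are not confirmed by the snippet itself.

import torch

# Assumed sizes: 2 stacked layers, batch of 4, hidden size 16,
# and an encoder context of 10 source steps.
num_layers, batch, dim, src_len = 2, 4, 16, 10

# Hypothetical constructor; adjust to the actual class definition.
decoder = StackedAttentionLSTM(input_size=dim, rnn_size=dim,
                               num_layers=num_layers, dropout=0.1).cuda()

input = torch.randn(batch, dim).cuda()            # current decoder input
h_0 = torch.zeros(num_layers, batch, dim).cuda()  # stacked hidden states
c_0 = torch.zeros(num_layers, batch, dim).cuda()  # stacked cell states
ctx = torch.randn(batch, src_len, dim).cuda()     # encoder outputs
ctx_mask = torch.ones(batch, src_len).cuda()      # 1 = real token, 0 = pad

output, (h_1, c_1) = decoder(input, (h_0, c_0), ctx, ctx_mask)
# output: (batch, dim); h_1, c_1: (num_layers, batch, dim),
# ready to be fed back in as `hidden` on the next decoding step.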