def forward(self, input, lengths):
    """Encode a padded batch with the conv bank, highway layer and BGRU.

    Args:
        input: batch tensor laid out as N x T x H (batch, time, features).
        lengths: per-sequence valid lengths, passed unchanged to
            ``rnn.pack_padded_sequence``. NOTE(review): that function's
            default mode expects lengths sorted in decreasing order --
            confirm callers sort the batch.

    Returns:
        BGRU outputs re-padded by ``rnn.pad_packed_sequence`` in
        batch-first layout (N x T' x 2H per the original annotation, where
        T' is the longest sequence in the batch).
    """
    batch, steps = input.size(0), input.size(1)
    channels_first = input.transpose(1, 2)  # N x T x H -> N x H x T

    def pad_time_right(x):
        # Append one zero frame at the end of the time axis. The temporary
        # trailing singleton dim makes (0, 0, 0, 1) address the time axis
        # as F.pad's second-to-last dimension. Presumably this gives the
        # even-index bank filters the same output length as the others --
        # conv_bank's kernel/padding settings are not visible here, TODO
        # confirm.
        return F.pad(x.unsqueeze(-1), (0, 0, 0, 1)).squeeze(-1)

    bank = [
        self.conv_bank[idx](
            pad_time_right(channels_first) if idx % 2 == 0 else channels_first
        )
        for idx in range(self.num_filters)
    ]

    # Channel-wise concat of all bank branches (original annotation:
    # N x HF x T); branches must agree on time length for cat to succeed.
    stacked = F.relu(self.bn_list[0](torch.cat(bank, dim=1)))
    # kernel 2 / stride 1 with no padding shortens the time axis by one frame
    stacked = F.max_pool1d(stacked, 2, stride=1)

    projected = F.relu(self.bn_list[1](self.conv1(stacked)))
    projected = self.bn_list[2](self.conv2(projected))
    residual = projected.transpose(1, 2)  # N x C x T -> N x T x C

    # Project the input only when its shape cannot be added to the residual.
    skip = input if input.size() == residual.size() else self.residual_proj(input)
    merged = self.highway(skip + residual).view(batch, steps, -1)

    packed = rnn.pack_padded_sequence(merged, lengths, batch_first=True)
    encoded, _ = self.BGRU(packed)  # no h_0 passed, so the GRU starts from zeros
    encoded, _ = rnn.pad_packed_sequence(encoded, batch_first=True)
    return encoded
评论列表
文章目录