def encode(self, x):
    """Encode an input batch with the conv front end followed by the RNN.

    A singleton dimension is inserted at index 1 of ``x`` before it is
    passed to ``self.conv``. The conv output is expected to be laid out as
    (batch, channels, time, freq); that layout is an assumption carried
    over from the surrounding code and is not checked here.

    Returns the output sequence of ``self.rnn``. When
    ``self.rnn.bidirectional`` is true, the two halves of the last
    dimension are added element-wise, so the result is half as wide as the
    raw RNN output.
    """
    conv_out = self.conv(x.unsqueeze(1))

    # Swap axes 1 and 2 so time comes right after batch:
    # (batch, time, channels, freq) under the layout assumed above.
    # contiguous() materialises the transposed layout so view() can
    # reinterpret it below.
    time_major = torch.transpose(conv_out, 1, 2).contiguous()

    # Merge the two trailing axes into a single feature axis, giving the
    # (batch, time, features) input the RNN is fed with.
    n_batch, n_steps, n_chan, n_freq = time_major.size()
    rnn_in = time_major.view((n_batch, n_steps, n_chan * n_freq))

    # The RNN returns a pair; only the first element (the per-step
    # output) is used, the second is discarded.
    rnn_out, _ = self.rnn(rnn_in)

    if not self.rnn.bidirectional:
        return rnn_out

    # NOTE(review): presumably the two halves of the last dimension are
    # the forward and backward directions (as with torch.nn RNN modules);
    # confirm against the type of self.rnn.
    half = rnn_out.size()[-1] // 2
    return rnn_out[:, :, :half] + rnn_out[:, :, half:]
评论列表
文章目录