def forward(self, H_j_dec, input_x):
    """Predict a non-negative mask and apply it to the cropped input.

    Args:
        H_j_dec: Decoder representation fed to ``self.ffDec``.
            NOTE(review): presumably shaped (B, T - 2 * self._L, N) so that
            the mask lines up with the cropped input -- confirm with caller.
        input_x: Input of shape (B, T, N), given as a NumPy array (the
            original contract) or as a ``torch.Tensor``.

    Returns:
        Tuple ``(Y_j, mask_t1)``:
        - ``mask_t1 = self.relu(self.ffDec(H_j_dec))``;
        - ``Y_j`` is the element-wise product of ``mask_t1`` and the input,
          cropped by ``self._L`` entries at both ends of axis 1.
    """
    # Decode/sparsify the mask first. The input can then be placed on the
    # mask's device, and hence the model's. The old code guessed the device
    # from the build-time `torch.has_cudnn` flag, which is deprecated and is
    # True on cuDNN builds even when no GPU is usable.
    mask_t1 = self.relu(self.ffDec(H_j_dec))

    # Crop self._L entries from both ends of axis 1 (T per the shape above).
    # The stop index is computed explicitly because `-self._L` would be
    # `-0 == 0` when self._L is 0, which would make the slice empty.
    # NOTE(review): the original comment called these "frequency
    # sub-bands", but axis 1 is labelled T -- confirm which axis this is.
    crop = self._L
    stop = input_x.shape[1] - crop
    input_x = torch.as_tensor(input_x[:, crop:stop, :]).to(mask_t1.device)
    # The original wrapped the input in Variable(..., requires_grad=True).
    # Keep gradient tracking on the input for callers that may rely on it.
    input_x.requires_grad_(True)

    # Apply skip-filtering connections.
    Y_j = torch.mul(mask_t1, input_x)
    return Y_j, mask_t1
评论列表
文章目录