def mask_prev_h(self, prev_h):
    """Select entries of each carried-over hidden state with ``self.len_wave_mask``.

    Each tensor in ``prev_h`` is index-selected along dim 1 using
    ``self.len_wave_mask`` and detached from the autograd graph. When no mask
    is set, ``prev_h`` is returned unchanged (and is NOT detached).

    Args:
        prev_h: Sequence of hidden-state tensors, e.g. ``(h_att, h_dec1,
            h_dec2)``. Any number of tensors is accepted. Dim 1 is presumably
            the batch axis (an older inline comment hinted at this) — TODO
            confirm. ``self.len_wave_mask`` is assumed to be a 1-D integer
            index tensor valid for that axis.

    Returns:
        A tuple of the selected, detached tensors in the same order as
        ``prev_h``, or ``prev_h`` itself when ``self.len_wave_mask`` is None.

    Side effects:
        When ``self.use_gpu`` is truthy, ``self.len_wave_mask`` is replaced by
        its ``.cuda()`` copy; later ``.cuda()`` calls on it are no-ops.
    """
    mask = self.len_wave_mask
    if mask is None:
        return prev_h
    if self.use_gpu:
        # index_select needs the index on the same device as the states;
        # store the moved mask back on self, as the original code did.
        mask = self.len_wave_mask = mask.cuda()
    # ``.detach()`` replaces the deprecated ``Variable(h.data)`` idiom: the
    # returned states carry no autograd history, so gradients do not flow
    # back through them into whatever produced ``prev_h``.
    return tuple(torch.index_select(h.detach(), 1, mask) for h in prev_h)
评论列表
文章目录