def prepare_decoding(self, state, lengths, train=True):
    """Build the decoder's initial state with a length-scaled cell state.

    Delegates to the parent class's ``prepare_decoding`` (forwarding
    ``state``, ``lengths`` and ``train`` unchanged). It keeps only the
    ``'x'`` and ``'h'`` entries of the parent's result and replaces the cell
    state with ``self.encoder.c0``. That value is broadcast to
    ``(self.batchsize, self.dim_hid)`` and each row is multiplied by the
    matching entry of ``lengths``.

    Args:
        state: Decoder state passed through to the parent implementation.
        lengths: Per-example values used as row-wise scale factors. They are
            cast to float32 and reshaped to ``(self.batchsize, 1)``, so they
            must contain exactly ``self.batchsize`` elements. Presumably a
            NumPy/CuPy array, since ``.astype`` is called on it (confirm
            against callers).
        train: Only forwarded to the parent implementation.

    Returns:
        dict with exactly the keys ``'x'``, ``'c'`` and ``'h'``. Any other
        keys produced by the parent (including its own ``'c'``) are dropped.
    """
    parent_state = super().prepare_decoding(state, lengths, train=train)
    inputs, hidden = parent_state['x'], parent_state['h']

    batch_shape = (self.batchsize, self.dim_hid)
    initial_cell = F.broadcast_to(self.encoder.c0, batch_shape)

    # A (batchsize, 1) column broadcasts across the hidden dimension, so every
    # unit in row i is scaled by lengths[i].
    row_scale = lengths.astype(np.float32).reshape((self.batchsize, 1))

    return {'x': inputs, 'c': initial_cell * row_scale, 'h': hidden}
评论列表
文章目录