def _transform_decoder_init_state(self, hn):
    """Build the decoder's initial recurrent state from the encoder's final state.

    The encoder state tensor has shape ``[num_dir, batch, hidden_size]`` where
    ``num_dir`` must be even (``2 * num_layers``). Rows ``2*l`` and ``2*l + 1``
    of dim 0 are concatenated along the feature axis, producing a
    ``[num_layers, batch, 2 * hidden_size]`` tensor that is then projected
    with ``self.hx_dense``.
    NOTE(review): this pairing assumes PyTorch's bidirectional RNN layout,
    where those two rows are the forward/backward states of layer ``l``.
    It also presumes ``hx_dense`` maps ``2 * hidden_size -> hidden_size``
    (e.g. an ``nn.Linear``); confirm against the constructor.

    Args:
        hn: Either a single state tensor (GRU/RNN style) or a 2-tuple
            ``(h, c)`` (LSTM style).

    Returns:
        For a tensor input: ``tanh(hx_dense(merged_hn))``.
        For a tuple input: the tuple ``(tanh(c_new), c_new)`` with
        ``c_new = hx_dense(merged_c)``. The incoming ``h`` is ignored; the
        decoder's hidden state is derived from the projected cell state.

    Raises:
        ValueError: if the leading dimension is odd and so cannot be split
            into pairs of direction states.
    """

    def merge_directions(state):
        # state: [2 * num_layers, batch, hidden_size]
        num_dir, batch, hidden_size = state.size()
        if num_dir % 2 != 0:
            raise ValueError(
                "expected an even leading dimension (2 * num_layers), got %d" % num_dir
            )
        # -> [batch, 2 * num_layers, hidden_size]; view() below needs contiguous memory.
        state = state.transpose(0, 1).contiguous()
        # -> [batch, num_layers, 2 * hidden_size] -> [num_layers, batch, 2 * hidden_size].
        # Floor division: view() requires ints, and `/` gives a float on Python 3.
        return state.view(batch, num_dir // 2, 2 * hidden_size).transpose(0, 1)

    if isinstance(hn, tuple):
        _, cn = hn
        # take hx_dense to [num_layers, batch, hidden_size]
        cn = self.hx_dense(merge_directions(cn))
        # The new hidden state is tanh of the new cell state.
        # Tensor.tanh() is numerically identical to F.tanh.
        return cn.tanh(), cn

    # take hx_dense to [num_layers, batch, hidden_size], then squash with tanh.
    return self.hx_dense(merge_directions(hn)).tanh()
评论列表
文章目录