goru.py source code

python

Project: URNN-PyTorch  Author: jingli9111
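import torch


def forward(self, input_, hx):
    """
    Args:
        input_: A (batch, input_size) tensor containing input
            features.
        hx: The current hidden state, a (batch, hidden_size) tensor.

    Returns:
        newh: A (batch, hidden_size) tensor containing the next hidden state.
    """
    batch_size = hx.size(0)

    # Pre-activations for both gates at once: hx @ gate_W + gate_bias + input_ @ gate_U.
    bias_batch = (self.gate_bias.unsqueeze(0)
                  .expand(batch_size, *self.gate_bias.size()))
    gate_Wh = torch.addmm(bias_batch, hx, self.gate_W)
    gate_Ux = torch.mm(input_, self.gate_U)
    # Split into the reset gate r and the update gate z. The split size is passed
    # positionally because the old `split_size=` keyword was renamed in later PyTorch.
    r, z = torch.split(gate_Ux + gate_Wh, self.hidden_size, dim=1)
    # Squash both gates into (0, 1), as in the GORU formulation.
    r = torch.sigmoid(r)
    z = torch.sigmoid(z)

    # Candidate state: input projection plus the reset-gated orthogonal (EUNN) rotation of hx.
    Ux = torch.mm(input_, self.U)
    unitary = self._EUNN(hx=hx, thetaA=self.thetaA, thetaB=self.thetaB)
    unitary = unitary * r
    newh = Ux + unitary
    newh = self._modReLU(newh, self.bias)
    # z keeps the previous state; (1 - z) admits the candidate.
    newh = hx * z + (1 - z) * newh

    return newh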
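Because `_EUNN` and `_modReLU` are not shown in this excerpt, the following framework-free sketch spells out one GORU step on plain Python lists for a single example (no batch dimension). It is an illustration under stated assumptions, not the repository's implementation: `GORUParams`, `matvec`, `givens_layer`, `orthogonal_transform`, `mod_relu` and `goru_step` are hypothetical names; the orthogonal transform is a stand-in for `_EUNN` built from two layers of 2x2 Givens rotations (layer A on pairs (0,1),(2,3),..., layer B on pairs (1,2),(3,4),...), loosely mirroring the `thetaA`/`thetaB` names; modReLU is assumed to take its real-valued form sign(z) * max(|z| + b, 0); and both gates are assumed to pass through a sigmoid.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


@dataclass
class GORUParams:
    """Hypothetical parameter bundle; names loosely mirror the attributes used in forward."""
    gate_U: Matrix     # (input_size, 2 * hidden_size): input -> [r | z]
    gate_W: Matrix     # (hidden_size, 2 * hidden_size): state -> [r | z]
    gate_bias: Vector  # (2 * hidden_size,)
    U: Matrix          # (input_size, hidden_size): input -> candidate
    theta_a: Vector    # hidden_size // 2 angles for layer A
    theta_b: Vector    # hidden_size // 2 - 1 angles for layer B
    bias: Vector       # (hidden_size,) modReLU bias


def matvec(x: Sequence[float], w: Matrix) -> Vector:
    """Row vector times matrix (x @ w), i.e. torch.mm for a single example."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def sigmoid(v: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def mod_relu(z: Sequence[float], bias: Sequence[float]) -> Vector:
    """Assumed real-valued modReLU: sign(z) * max(|z| + b, 0), with sign(0) = 0."""
    out = []
    for zi, bi in zip(z, bias):
        sign = (zi > 0) - (zi < 0)
        out.append(sign * max(abs(zi) + bi, 0.0))
    return out


def givens_layer(h: Sequence[float], thetas: Sequence[float], offset: int) -> Vector:
    """Rotate disjoint adjacent pairs (offset + 2k, offset + 2k + 1) by thetas[k].

    Each 2x2 rotation is orthogonal, so the layer preserves the Euclidean norm of h.
    """
    out = list(h)
    for k, theta in enumerate(thetas):
        i = offset + 2 * k
        j = i + 1
        c, s = math.cos(theta), math.sin(theta)
        out[i] = c * h[i] - s * h[j]
        out[j] = s * h[i] + c * h[j]
    return out


def orthogonal_transform(h: Sequence[float], theta_a: Sequence[float],
                         theta_b: Sequence[float]) -> Vector:
    """Hypothetical stand-in for _EUNN: layer A (offset 0), then layer B (offset 1)."""
    return givens_layer(givens_layer(h, theta_a, offset=0), theta_b, offset=1)


def goru_step(x: Sequence[float], h: Sequence[float], p: GORUParams) -> Vector:
    """One GORU update for a single example."""
    n = len(h)
    # Joint gate pre-activations, then split into reset gate r and update gate z.
    pre = [a + b + c for a, b, c in
           zip(matvec(x, p.gate_U), matvec(h, p.gate_W), p.gate_bias)]
    r = [sigmoid(v) for v in pre[:n]]
    z = [sigmoid(v) for v in pre[n:]]
    # Candidate: x @ U plus the reset-gated rotation of h, passed through modReLU.
    rotated = orthogonal_transform(h, p.theta_a, p.theta_b)
    cand = mod_relu([ux + ri * rot for ux, ri, rot in zip(matvec(x, p.U), r, rotated)],
                    p.bias)
    # z keeps the previous state; (1 - z) admits the candidate.
    return [zi * hi + (1.0 - zi) * ci for zi, hi, ci in zip(z, h, cand)]
```

As a quick check: with every weight, bias and angle set to zero, both gates equal sigmoid(0) = 0.5, the rotation is the identity, and `goru_step` returns 0.75 * h; for example, x = [3.0] and h = [1.0, -2.0] give [0.75, -1.5].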