urnn.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:URNN-PyTorch 作者: jingli9111 项目源码 文件源码
def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: initial hidden, where the size of the state is
                (batch, hidden_size).

        Returns:
            newh: Tensors containing the next hidden state.
        """
        batch_size = hx.size(0)
        Ux = torch.mm(input_, self.U)
        hx = Ux + hx
        newh = self._EUNN(hx=hx, thetaA=self.thetaA, thetaB=self.thetaB)
        newh = self._modReLU(newh, self.bias)
        return newh
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号