highway.py source code

python

Project: LasagneNLP  Author: XuezheMax
def get_output_for(self, input, **kwargs):
        # if the input has more than two dimensions, flatten it into a
        # batch of feature vectors.
        input_reshape = input.flatten(2) if input.ndim > 2 else input

        activation = T.dot(input_reshape, self.W_h)
        if self.b_h is not None:
            activation = activation + self.b_h.dimshuffle('x', 0)
        # apply the nonlinearity whether or not a bias is present
        activation = self.nonlinearity(activation)

        transform = T.dot(input_reshape, self.W_t)
        if self.b_t is not None:
            transform = transform + self.b_t.dimshuffle('x', 0)
        # the transform gate is always squashed into (0, 1)
        transform = nonlinearities.sigmoid(transform)

        carry = 1.0 - transform

        output = activation * transform + input_reshape * carry
        # reshape output back to the original input shape
        if input.ndim > 2:
            output = T.reshape(output, input.shape)

        return output
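
The method above depends on Theano and on layer attributes defined elsewhere in the class, so it cannot be run on its own. As a rough illustration only, the same computation can be sketched in NumPy. The function name `highway_forward`, the helper `sigmoid`, and the `tanh` default are assumptions made for this sketch and are not part of LasagneNLP. The weight matrices are assumed to be square (features x features), since the carry path adds the input to the transformed activation elementwise.

```python
import numpy as np


def sigmoid(z):
    # Logistic function used for the transform gate.
    return 1.0 / (1.0 + np.exp(-z))


def highway_forward(x, W_h, b_h, W_t, b_t, nonlinearity=np.tanh):
    """Hypothetical NumPy counterpart of get_output_for (illustration only)."""
    # Flatten everything after the batch axis into one feature axis,
    # mirroring input.flatten(2) in the Theano code.
    x2d = x.reshape(x.shape[0], -1) if x.ndim > 2 else x

    # Candidate activation H(x).
    activation = x2d @ W_h
    if b_h is not None:
        activation = activation + b_h
    activation = nonlinearity(activation)

    # Transform gate T(x), with values in (0, 1).
    transform = x2d @ W_t
    if b_t is not None:
        transform = transform + b_t
    transform = sigmoid(transform)

    # Carry gate is the complement of the transform gate.
    carry = 1.0 - transform
    output = activation * transform + x2d * carry

    # Restore the original shape if the input was flattened.
    return output.reshape(x.shape) if x.ndim > 2 else output


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 3, 2))   # batch of 4, 3 * 2 = 6 features
    W_h = rng.standard_normal((6, 6))
    W_t = np.zeros((6, 6))
    # A strongly negative gate bias closes the gate, so the layer
    # passes its input through almost unchanged.
    y = highway_forward(x, W_h, np.zeros(6), W_t, np.full(6, -50.0))
    print(y.shape, np.allclose(y, x))
```

A strongly negative `b_t` pushes the gate toward 0, so the output stays close to the input. A strongly positive `b_t` pushes it toward 1, so the output approaches the plain nonlinear activation.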