relevance_based.py source file

python

Project: nn-patterns  Author: pikinder
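def _get_normalised_relevance_layer(self, layer, feeder):
        # Divides the incoming relevance (feeder) by the layer's stabilised
        # pre-activations. Assumes the usual aliases, which this excerpt does
        # not show: T = theano.tensor, L = lasagne.layers.

        def add_epsilon(Zs):
            # Shift each value away from zero by epsilon, keeping its sign
            # (zero counts as positive), so the later division is safe.
            tmp = (T.cast(Zs >= 0, theano.config.floatX)*2.0 - 1.0)
            return Zs + self.epsilon * tmp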

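        # Rebuild the layer with the same W and b but no nonlinearity, so its
        # output is the pre-activation Z.
        if isinstance(layer, L.DenseLayer):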
            forward_layer = L.DenseLayer(layer.input_layer,
                                         layer.num_units,
                                         W=layer.W,
                                         b=layer.b,
                                         nonlinearity=None)
        elif isinstance(layer, L.Conv2DLayer):
            forward_layer = L.Conv2DLayer(layer.input_layer,
                                          num_filters=layer.num_filters,
                                          W=layer.W,
                                          b=layer.b,
                                          stride=layer.stride,
                                          filter_size=layer.filter_size,
                                          flip_filters=layer.flip_filters,
                                          untie_biases=layer.untie_biases,
                                          pad=layer.pad,
                                          nonlinearity=None)
        else:
            raise NotImplementedError()

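        # Invert the stabilised pre-activations, then multiply element-wise
        # with the incoming relevance: feeder / (Z + epsilon * sign(Z)).
        forward_layer = L.ExpressionLayer(forward_layer,
                                          lambda x: 1.0 / add_epsilon(x))
        feeder = L.ElemwiseMergeLayer([forward_layer, feeder],
                                      merge_function=T.mul)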

        return feeder
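
The arithmetic of the method above can be checked without Theano or Lasagne. The following is a minimal sketch in plain Python, not the project's code: the function names `add_epsilon` and `normalised_relevance`, the `epsilon` default, and the sample values are all illustrative assumptions. It mirrors the stabiliser (zero treated as non-negative, as in the `Zs >= 0` test) and the element-wise division of relevance by the stabilised pre-activations.

```python
def add_epsilon(zs, epsilon=1e-4):
    """Shift each pre-activation away from zero by epsilon, keeping its sign.

    Zero is treated as non-negative, matching the `Zs >= 0` test in the
    original method. The default epsilon is an assumption for illustration.
    """
    return [z + epsilon * (1.0 if z >= 0 else -1.0) for z in zs]


def normalised_relevance(zs, relevance, epsilon=1e-4):
    """Element-wise relevance / (z + epsilon * sign(z))."""
    stabilised = add_epsilon(zs, epsilon)
    return [r / s for r, s in zip(relevance, stabilised)]


# Hypothetical pre-activations and incoming relevance values.
zs = [2.0, -0.5, 0.0]
relevance = [1.0, 1.0, 1.0]

# A large epsilon is used here only to make the effect easy to see.
result = normalised_relevance(zs, relevance, epsilon=0.5)
print(result)
```

Note that the zero pre-activation no longer causes a division by zero: it is shifted to `+epsilon` before being inverted, which is the purpose of the stabiliser.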