layers.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:crfrnn_layer 作者: HapeMask 项目源码 文件源码
def __init__(self, unary, ref, sxy_bf=70, sc_bf=10, compat_bf=6,
                 sxy_spatial=2, compat_spatial=2, num_iter=5,
                 normalize_final_iter=True, trainable_kernels=False,
                 name=None):
        """Build a CRF-as-RNN layer from a unary layer and a reference layer.

        Parameters
        ----------
        unary : Layer
            Layer producing the unary potentials. Axis 1 of its output
            shape is taken as ``val_dim`` (presumably the number of labels;
            layout assumed to be (batch, labels, H, W) -- TODO confirm).
        ref : Layer
            Layer producing the reference image for the bilateral kernel.
            Axis 1 of its output shape must be 1 (greyscale) or 3 (color).
        sxy_bf : float
            Initial std-dev of the bilateral kernel along each of the two
            spatial coordinates.
        sc_bf : float
            Initial std-dev of the bilateral kernel along each reference
            channel.
        compat_bf, compat_spatial : float
            Compatibility weights; only stored here, used elsewhere.
        sxy_spatial : float
            First argument passed to ``gkern`` to build the spatial kernel.
        num_iter : int
            Stored on the layer; used outside this method.
        normalize_final_iter : bool
            Stored on the layer; used outside this method.
        trainable_kernels : bool
            Whether the bilateral std-devs and the spatial kernel are
            registered as trainable parameters.
        name : str or None
            Layer name forwarded to the base-class constructor.

        Raises
        ------
        ValueError
            If axis 1 of either input's output shape is unknown (``None``),
            or if ``ref`` does not have exactly 1 or 3 channels.
        """
        super(CRFasRNNLayer, self).__init__(incomings=[unary, ref], name=name)

        self.sxy_bf = sxy_bf
        self.sc_bf = sc_bf
        self.compat_bf = compat_bf
        self.sxy_spatial = sxy_spatial
        self.compat_spatial = compat_spatial
        self.num_iter = num_iter
        self.normalize_final_iter = normalize_final_iter

        val_dim = ll.get_output_shape(unary)[1]
        ref_channels = ll.get_output_shape(ref)[1]

        # Reject unknown (None) channel counts before anything uses them:
        # the 1/3 check, the "+2" arithmetic and gkern() all need real ints.
        # Checking this last (after use) meant it could never fire.
        if None in (val_dim, ref_channels):
            raise ValueError("CRF RNN requires known channel dimensions for "
                             "all inputs.")
        if ref_channels not in (1, 3):
            raise ValueError("Reference image must be either color or "
                             "greyscale (1 or 3 channels).")

        self.val_dim = val_dim
        # +2 for the two spatial coordinates of the bilateral grid.
        self.ref_dim = ref_channels + 2

        # One std-dev per bilateral feature: two spatial coordinates first,
        # then one per reference channel (1 or 3).
        kstd_bf = np.array([sxy_bf] * 2 + [sc_bf] * ref_channels, np.float32)

        self.kstd_bf = self.add_param(kstd_bf, (self.ref_dim,),
                                      name="kern_std",
                                      trainable=trainable_kernels,
                                      regularizable=False)

        gk = gkern(sxy_spatial, self.val_dim)
        self.W_spatial = self.add_param(gk, gk.shape, name="spatial_kernel",
                                        trainable=trainable_kernels,
                                        regularizable=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号