def __init__(self, unary, ref, sxy_bf=70, sc_bf=10, compat_bf=6,
             sxy_spatial=2, compat_spatial=2, num_iter=5,
             normalize_final_iter=True, trainable_kernels=False,
             name=None):
    """Build a CRF-as-RNN layer from a unary-term layer and a reference layer.

    Parameters
    ----------
    unary : Layer
        Incoming layer; its channel count (``get_output_shape(unary)[1]``)
        is stored as ``self.val_dim``.
    ref : Layer
        Incoming reference-image layer; must have 1 or 3 channels.
    sxy_bf : float
        Initial value of the first two entries of the bilateral kernel
        std parameter ``kstd_bf`` (presumably the x/y axes — confirm).
    sc_bf : float
        Initial value of the remaining ``kstd_bf`` entries, one per
        reference channel.
    compat_bf, compat_spatial : float
        Stored as ``self.compat_bf`` / ``self.compat_spatial``; presumably
        label-compatibility weights of the bilateral / spatial
        messages — TODO confirm against the layer's forward pass.
    sxy_spatial : float
        Passed to ``gkern`` to build the initial spatial kernel
        ``W_spatial``.
    num_iter : int
        Stored as ``self.num_iter`` (presumably the number of mean-field
        iterations — confirm).
    normalize_final_iter : bool
        Stored as ``self.normalize_final_iter``.
    trainable_kernels : bool
        Whether ``kstd_bf`` and ``W_spatial`` are registered as trainable
        parameters. Neither is ever regularizable.
    name : str or None
        Layer name forwarded to the base class.

    Raises
    ------
    ValueError
        If either input's channel dimension is unknown (``None``), or if
        ``ref`` does not have exactly 1 or 3 channels.
    """
    super(CRFasRNNLayer, self).__init__(incomings=[unary, ref], name=name)
    self.sxy_bf = sxy_bf
    self.sc_bf = sc_bf
    self.compat_bf = compat_bf
    self.sxy_spatial = sxy_spatial
    self.compat_spatial = compat_spatial
    self.num_iter = num_iter
    self.normalize_final_iter = normalize_final_iter

    val_dim = ll.get_output_shape(unary)[1]
    ref_channels = ll.get_output_shape(ref)[1]
    # Validate before any arithmetic or kernel construction uses these
    # values. Checking afterwards would fail first on `None + 2` or inside
    # gkern(..., None) instead of raising this clear error.
    if None in (val_dim, ref_channels):
        raise ValueError("CRF RNN requires known channel dimensions for "
                         "all inputs.")
    if ref_channels not in (1, 3):
        raise ValueError("Reference image must be either color or "
                         "greyscale (1 or 3 channels).")

    self.val_dim = val_dim
    # +2 for bilateral grid: two sxy_bf-scaled dimensions in addition to
    # the reference channels.
    self.ref_dim = ref_channels + 2

    # One std per bilateral dimension: two sxy_bf entries followed by one
    # sc_bf entry per reference channel (length == self.ref_dim).
    kstd_bf = np.array([sxy_bf, sxy_bf] + [sc_bf] * ref_channels,
                       np.float32)
    self.kstd_bf = self.add_param(kstd_bf, (self.ref_dim,),
                                  name="kern_std",
                                  trainable=trainable_kernels,
                                  regularizable=False)

    # Spatial kernel; its shape is whatever gkern returns for this
    # (sxy_spatial, val_dim) pair.
    gk = gkern(sxy_spatial, self.val_dim)
    self.W_spatial = self.add_param(gk, gk.shape, name="spatial_kernel",
                                    trainable=trainable_kernels,
                                    regularizable=False)
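# NOTE: two lines of scraped web-page residue stood here ("comment list" /
# "table of contents" navigation labels); they are not part of the code.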