def get_output_for(self, inputs, **kwargs):
    """Gaussian-filter ``vals`` guided by ``ref``, one leading-axis slice at a time.

    ``inputs`` is unpacked as a ``(vals, ref)`` pair. ``theano.scan`` walks the
    leading axis of both tensors in lockstep. For every pair of slices it
    calls ``gaussian_filter(ref_slice, val_slice, self.kern_std)``, optionally
    wrapped in a normalization selected by ``self.norm_type``:

    * ``None``   -- no normalization; the raw filter response is returned.
    * ``"pre"``  -- values are divided by the normalizer before filtering.
    * ``"post"`` -- the filter response is divided by the normalizer.
    * ``"sym"``  -- the normalizer is square-rooted, then applied both before
      and after filtering.

    The normalizer is the filter response (computed with ``self.ref_dim``) to
    an all-ones tensor of shape ``(1, val_slice.shape[1], val_slice.shape[2])``.
    ``1e-8`` is then added to it (after the square root in the ``"sym"`` case)
    to avoid division by zero. Any other non-``None`` ``norm_type`` builds
    the normalizer but never applies it.

    Parameters
    ----------
    inputs : sequence of two symbolic tensors
        ``(vals, ref)``, scanned jointly over their leading axis. Each
        ``vals`` slice must have at least three axes; their meaning
        (e.g. channels, height, width) is presumed -- TODO confirm.
    **kwargs
        Accepted for the layer interface; not used here.

    Returns
    -------
    Symbolic tensor stacking the per-slice outputs along the leading axis.
    """
    vals, ref = inputs

    def _filter_slice(val_t, ref_t):
        # Step function for theano.scan: processes one (vals, ref) slice pair.
        mode = self.norm_type
        if mode is None:
            # No normalization requested: plain filter response.
            return gaussian_filter(ref_t, val_t, self.kern_std)

        # Normalizer: filter response to an all-ones tensor with a size-1
        # leading axis and axes 1 and 2 matching the value slice.
        # NOTE(review): dtype is hard-coded to float32 rather than following
        # theano.config.floatX / val_t.dtype -- confirm this is intended.
        ones = tt.ones((1, val_t.shape[1], val_t.shape[2]), np.float32)
        denom = gaussian_filter(ref_t, ones, self.kern_std, self.ref_dim)
        if mode == "sym":
            denom = tt.sqrt(denom)
        denom += 1e-8  # guard against division by zero

        if mode in ("pre", "sym"):
            val_t = val_t / denom
        # NOTE(review): unlike the normalizer call above, self.ref_dim is not
        # passed here -- confirm gaussian_filter's default is what is wanted.
        response = gaussian_filter(ref_t, val_t, self.kern_std)
        if mode in ("post", "sym"):
            response = response / denom
        return response

    # theano.scan returns (outputs, updates); only the outputs are needed.
    filtered, _updates = theano.scan(
        fn=_filter_slice, sequences=[vals, ref], outputs_info=None
    )
    return filtered
# (Stray web-page navigation labels, translated from Chinese: "Comment list", "Table of contents".)