def feed(self, im):
    """Run one image through every layer of the hierarchy, bottom to top.

    The image is resized to ``self.im_shape`` with ``cv2.INTER_AREA`` and
    shifted by -127.5, which turns the uint8 pixels into floats centred on
    mid-grey. Each layer ``self.Vs[i]`` is then asked to ``sparsify`` its
    input. When the layer has ``use_feedback`` set and a layer exists above
    it, ``sparsify`` also receives a context built by ``gen_context`` from
    the layer above. The ``c`` returned by layer ``i`` is regrouped by
    ``group_NxN_input`` to form the input of layer ``i + 1``.

    Args:
        im: Image array; its dtype must be ``np.uint8``.

    Returns:
        ``(ss, cs)``: two lists of length ``len(self.Vs)``. They hold the
        ``s`` and ``c`` values returned by each layer's ``sparsify``.
        Processing stops at the first layer whose ``c`` is None. Entries
        for the layers above it stay None.

    Raises:
        TypeError: if ``im.dtype`` is not ``np.uint8``.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if im.dtype != np.uint8:
        raise TypeError("feed expects a uint8 image, got %s" % im.dtype)
    im = cv2.resize(im, dsize=self.im_shape, interpolation=cv2.INTER_AREA)
    im = im - 127.5  # uint8 -> float, centred on the mid-grey value

    num_levels = len(self.Vs)
    ss = [None] * num_levels
    cs = [None] * num_levels
    inp = im
    for i, layer in enumerate(self.Vs):
        has_parent = i + 1 < num_levels
        if has_parent:
            # Per-level dimensions for gen_context / group_NxN_input:
            # im_shape divided by the bottom layer's Xb/Yb, halved once per
            # level. Floor division keeps them integral (``/`` would give
            # floats on Python 3). NOTE(review): presumably Xb/Yb are the
            # bottom layer's block size in pixels -- confirm.
            rows = self.im_shape[0] // self.Vs[0].Xb // (2 ** i)
            cols = self.im_shape[1] // self.Vs[0].Yb // (2 ** i)

        if layer.use_feedback and has_parent:
            context = self.gen_context(self.Vs[i + 1], rows, cols)
        else:
            # Top level (or feedback disabled): no context from above.
            context = None
        s, c = layer.sparsify(inp, context=context)
        ss[i] = s
        cs[i] = c
        if c is None:
            # sparsify returns None if a layer isn't trained enough to
            # return a response, so the layers above cannot be fed.
            break
        if has_parent:
            # Input for the next level: ``c`` regrouped with N=2 and
            # ``K + 1`` entries per unit. NOTE(review): presumably 2x2
            # neighbourhoods of the bottom layer's K+1 code values -- confirm.
            inp = self.group_NxN_input(c, 2, self.Vs[0].K + 1, rows, cols)
    return ss, cs
评论列表
文章目录