def __call__(self, x, test=False, retain_forward=False):
    """Run ``x`` through the layer stack and return the final output.

    Layers are applied in this order: ``self.c_first``, then the
    intermediate layers ``self.c0`` .. ``self.c<down_layers - 2>`` (looked
    up by attribute name), then ``self.c_last``. Unless
    ``self.conv_as_last`` is true, the intermediate result is flattened to
    2-D immediately before ``c_last``.

    Args:
        x: Input passed to ``self.c_first``.
        test: Forwarded unchanged as the ``test`` keyword to every layer
            (presumably a train/inference switch -- confirm against the
            layer definitions).
        retain_forward: Forwarded unchanged as the ``retain_forward``
            keyword to every layer.

    Returns:
        Whatever ``self.c_last`` returns.

    Side effects:
        When ``self.conv_as_last`` is false, records the 4-element shape
        seen just before flattening in ``self.last_shape``.
    """
    # Every layer receives the same pair of keyword arguments.
    opts = {"test": test, "retain_forward": retain_forward}

    h = self.c_first(x, **opts)
    # Intermediate layers are stored as attributes c0, c1, ...; there are
    # down_layers - 1 of them, so they are fetched by name.
    for i in range(self.down_layers - 1):
        h = getattr(self, 'c' + str(i))(h, **opts)

    if not self.conv_as_last:
        # Flatten all but the first axis so c_last receives a 2-D
        # (batch, features) input. Unpacking into four names requires a
        # 4-D activation (assumed NCHW as usual for Chainer convolutions --
        # confirm); the product below does not depend on that layout.
        batch, channels, dim_a, dim_b = h.data.shape
        self.last_shape = (batch, channels, dim_a, dim_b)
        h = F.reshape(h, (batch, channels * dim_a * dim_b))

    return self.c_last(h, **opts)
# (Scraped web-page residue, not code: "评论列表" = "comment list",
#  "文章目录" = "table of contents".)