def recollect(self, w):
    """Populate ``self.w`` with the ``self.keep_idx`` entries of ``w``.

    When ``w`` is ``None`` the stored weights are replaced by ``None`` and
    nothing else happens. Otherwise each tensor read from ``w`` is
    sub-selected with ``np.take`` and written under the same key in
    ``self.w``:

    * ``'kernel'`` is indexed along axis 3 (presumably the output-channel
      axis -- confirm against the layer's kernel layout).
    * ``'biases'`` and, when ``self.batch_norm`` is truthy,
      ``'moving_mean'``, ``'moving_variance'`` and ``'gamma'`` are indexed
      without an axis argument, i.e. over the flattened array.

    Args:
        w: Mapping from the key names above to array-likes, or ``None``.

    Raises:
        KeyError: If ``w`` lacks one of the keys that is read.
    """
    if w is None:
        self.w = w
        return

    kept = self.keep_idx

    # Both tensors are looked up before anything is stored in self.w.
    kernel, biases = w['kernel'], w['biases']
    self.w['kernel'] = np.take(kernel, kept, axis=3)
    self.w['biases'] = np.take(biases, kept)

    if not self.batch_norm:
        return

    # Fetch all batch-norm tensors first, then select and store them in order.
    bn_keys = ('moving_mean', 'moving_variance', 'gamma')
    bn_tensors = [w[key] for key in bn_keys]
    for key, tensor in zip(bn_keys, bn_tensors):
        self.w[key] = np.take(tensor, kept)
# Residue from the source web page ("comment list", "article table of contents"); not part of the code.