def __call__(self, x, train):
    """Apply the conv1 -> bn1 links listed in ``self.forward``, then ReLU.

    Entries of ``self.forward`` are ``(name, f)`` pairs processed in list
    order. Only entries whose name contains ``'conv1'`` or ``'bn1'`` are
    applied; any other entry is skipped. Each applied link is looked up on
    ``self`` by ``name`` (presumably the same object as ``f`` -- confirm).

    Args:
        x: Input to the first applied link. Presumably a 4-D
            ``(batch, channels, height, width)`` Chainer variable/array;
            this is not enforced here.
        train: Truthy in training mode. Its negation is passed as the second
            positional argument of the batch-normalization link. This is
            presumably Chainer v1's ``test`` flag; confirm against the Chainer
            version in use.

    Returns:
        tuple: ``(F.relu(x), param_num)``. ``param_num`` is the number of
        learnable scalars in the applied links:
        - every element of the convolution's ``W``;
        - its ``b``, if present;
        - the batch-norm link's ``gamma`` and ``beta``, if present.
    """
    param_num = 0
    for name, f in self.forward:
        if 'conv1' in name:
            x = getattr(self, name)(x)
            # Read the sizes of the parameter arrays rather than assuming a 4-D
            # kernel plus one bias per output channel. A convolution created
            # without a bias has ``b`` set to None and adds no bias parameters.
            param_num += f.W.data.size
            bias = getattr(f, 'b', None)
            if bias is not None:
                param_num += bias.data.size
        elif 'bn1' in name:
            bn = getattr(self, name)
            x = bn(x, not train)
            # Count the scale/shift parameters that actually exist on the link
            # (Chainer's use_gamma/use_beta options can omit either one),
            # instead of assuming 2 * channels.
            for param in (getattr(bn, 'gamma', None), getattr(bn, 'beta', None)):
                if param is not None:
                    param_num += param.data.size
    return (F.relu(x), param_num)
# [(CONV -> BatchNorm -> ReLU -> CONV -> BatchNorm) + (x)]
评论列表
文章目录