def reset(self, stdv=None):
    """Re-initialise ``self.weight`` and ``self.bias`` in place with uniform noise.

    Args:
        stdv: If given, every weight and bias entry is drawn from
            ``U(-stdv * sqrt(3), stdv * sqrt(3))``, i.e. a uniform
            distribution whose standard deviation is ``stdv``.
            If ``None``, each bound is derived from the fan-in of an
            output plane: ``1 / sqrt(kW * kH * fan_in)``. The fan-in of
            plane ``p`` is the number of rows of ``self.connTable`` whose
            column 1 equals ``p``.

    All sampling goes through torch's RNG (``Tensor.uniform_``), so
    ``torch.manual_seed`` makes the initialisation reproducible.
    """
    if stdv is not None:
        # A uniform distribution on [-b, b] has standard deviation b / sqrt(3).
        bound = stdv * math.sqrt(3)
        self.weight.uniform_(-bound, bound)
        self.bias.uniform_(-bound, bound)
        return

    n_conn = self.connTable.size(0)
    area = self.kW * self.kH

    # fan_in[p] counts the connection-table rows whose column 1 (the
    # output-plane index) is p.
    fan_in = [0] * self.nOutputPlane
    for row in range(n_conn):
        fan_in[int(self.connTable[row, 1])] += 1

    # Each connection owns one weight slice (selected along dim 0); its
    # bound depends on the fan-in of the output plane it feeds.
    for row in range(n_conn):
        out_plane = int(self.connTable[row, 1])
        bound = 1. / math.sqrt(area * fan_in[out_plane])
        self.weight.select(0, row).uniform_(-bound, bound)

    for plane in range(self.bias.size(0)):
        # A plane that appears in no connection has fan-in 0; treat it as 1 so
        # the bound stays finite instead of raising ZeroDivisionError.
        bound = 1. / math.sqrt(area * max(fan_in[plane], 1))
        # narrow() yields a 1-element view, so the draw uses torch's RNG.
        self.bias.narrow(0, plane, 1).uniform_(-bound, bound)
评论列表
文章目录