def weight_init(m: nn.Module) -> None:
    """Initialize the parameters of a 2-D convolution layer in place.

    If ``m`` is an ``nn.Conv2d``, its weight is filled from a Xavier
    (Glorot) normal distribution and its bias, when present, is set to
    zero. Any other module is left untouched.

    The function takes a single module, so it can be passed directly to
    ``nn.Module.apply`` (e.g. ``model.apply(weight_init)``).

    Args:
        m: The module to (possibly) initialize.
    """
    if not isinstance(m, nn.Conv2d):
        return

    # Use the in-place ``*_`` initializers; the non-underscore names are
    # deprecated aliases that emit warnings.
    init.xavier_normal_(m.weight)

    # Conv2d(..., bias=False) leaves ``bias`` as None; skip it instead of
    # failing inside the initializer.
    if m.bias is not None:
        init.constant_(m.bias, 0)