def test_calculate_gain_leaky_relu_only_accepts_numbers(self):
for param in [True, [1], {'a': 'b'}]:
with self.assertRaises(ValueError):
init.calculate_gain('leaky_relu', param)
Python calculate_gain() example source code
def test_calculate_gain_only_accepts_valid_nonlinearities(self):
for n in [2, 5, 25]:
# Generate random strings of lengths that definitely aren't supported
random_string = ''.join([random.choice(string.ascii_lowercase) for i in range(n)])
with self.assertRaises(ValueError):
init.calculate_gain(random_string)
def reset_parameters(self):
    """Re-initialise W_s1 and W_s2 with Xavier/Glorot-uniform values.

    W_s1 is scaled by the gain ``calculate_gain`` reports for tanh and
    W_s2 by the gain for 'linear'. Uses the in-place ``xavier_uniform_``;
    the underscore-less ``xavier_uniform`` is a deprecated alias that
    emits a warning on every call.
    """
    tanh_gain = weight_init.calculate_gain('tanh')
    linear_gain = weight_init.calculate_gain('linear')
    # NOTE(review): the gain choice suggests W_s1 feeds a tanh and W_s2 is
    # applied without a nonlinearity -- confirm against forward().
    weight_init.xavier_uniform_(self.W_s1.data, tanh_gain)
    weight_init.xavier_uniform_(self.W_s2.data, linear_gain)
def reset_parameters(self):
    """Re-initialise W_x, W_y and W_z with Xavier/Glorot-uniform values.

    All three use the gain ``calculate_gain('linear')`` returns. They are
    initialised in the order W_x, W_y, W_z, which keeps the RNG
    consumption order unchanged. Uses the in-place ``xavier_uniform_``;
    the underscore-less ``xavier_uniform`` is a deprecated alias that
    emits a warning on every call.
    """
    linear_gain = weight_init.calculate_gain('linear')
    for weight in (self.W_x, self.W_y, self.W_z):
        weight_init.xavier_uniform_(weight.data, linear_gain)
def initWeight(self):
    """Initialise every parameter returned by ``self.named_parameters()``.

    Parameters whose name contains 'weight' get Xavier-uniform values
    scaled by the ReLU gain. Every other parameter (biases and anything
    else) is set to 0. Uses the in-place ``xavier_uniform_``/``constant_``;
    the underscore-less names are deprecated aliases that emit a warning
    on every call.
    """
    # The gain does not depend on the parameter, so compute it once.
    relu_gain = init.calculate_gain('relu')
    for name, params in self.named_parameters():
        if 'weight' in name:
            # Weights: Xavier-uniform initialisation.
            # NOTE(review): Xavier init needs at least 2 dimensions, so a 1-D
            # parameter whose name contains 'weight' (e.g. a norm layer's
            # scale) would raise ValueError here -- confirm the model has none.
            init.xavier_uniform_(params, relu_gain)
        else:
            # Everything else (biases): constant 0.
            init.constant_(params, 0)
def _initialize_weights(self):
    """Orthogonally initialise the weights of conv1..conv4.

    conv1-conv3 are scaled by the ReLU gain. conv4 keeps the default gain
    of 1 (presumably because its output is not followed by a ReLU --
    confirm against forward()). Layers are initialised in the order
    conv1..conv4, which keeps the RNG consumption order unchanged. Uses
    the in-place ``orthogonal_``; the underscore-less ``orthogonal`` is a
    deprecated alias that emits a warning on every call.
    """
    relu_gain = init.calculate_gain('relu')
    for conv in (self.conv1, self.conv2, self.conv3):
        init.orthogonal_(conv.weight, relu_gain)
    init.orthogonal_(self.conv4.weight)
# Create the super-resolution model by using the above model definition.