def test_regex_matches_are_initialized_correctly(self):
    """Check that regex-selected initializers are applied to the expected parameters.

    A small module with two linear layers and a 1-d convolution is passed
    through an ``InitializerApplicator`` built from a config holding two
    ``[regex, initializer]`` pairs: ``conv`` with constant 5 and
    ``funky_na.*bi`` with constant 7. The test asserts that every parameter
    of ``conv`` is filled with 5 and that the bias of
    ``linear_1_with_funky_name`` is filled with 7.
    """
    class Net(torch.nn.Module):
        """Minimal module whose parameters the configured regexes are matched against."""

        def __init__(self):
            super(Net, self).__init__()
            self.linear_1_with_funky_name = torch.nn.Linear(5, 10)
            self.linear_2 = torch.nn.Linear(10, 5)
            self.conv = torch.nn.Conv1d(5, 5, 5)

        def forward(self, inputs):  # pylint: disable=arguments-differ
            # Not invoked by this test; only the module's parameters are inspected.
            pass

    # pyhocon does funny things if there's a . in a key.  This test makes sure that we
    # handle these kinds of regexes correctly.  Note that the second regex below
    # contains a "." and is given as a list element, not as an object key.
    json_params = """{"initializer": [
    ["conv", {"type": "constant", "val": 5}],
    ["funky_na.*bi", {"type": "constant", "val": 7}]
    ]}
    """
    params = Params(pyhocon.ConfigFactory.parse_string(json_params))
    initializers = InitializerApplicator.from_params(params['initializer'])

    model = Net()
    initializers(model)

    # Every parameter yielded by the conv layer must hold the constant 5.
    for parameter in model.conv.parameters():
        assert torch.equal(parameter.data, torch.ones(parameter.size()) * 5)

    # The bias of the oddly named linear layer must hold the constant 7.
    parameter = model.linear_1_with_funky_name.bias
    assert torch.equal(parameter.data, torch.ones(parameter.size()) * 7)
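# Comment list          (web-page navigation residue, not part of the test)
# Article contents      (web-page navigation residue, not part of the test)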