import math

import torch
import torch.nn.init as init

# StackedGRU, StackedLSTM, NormalizedGRU, NormalizedGRUCell and
# StackedNormalizedGRU are project-specific RNN wrappers; they are assumed
# to be importable from the surrounding codebase.


def is_bias(p_name):
    # helper assumed by the snippet: PyTorch names bias parameters
    # 'bias', 'bias_ih_l0', 'bias_hh_l0', etc.
    return 'bias' in p_name


def make_initializer(
        linear={'type': 'uniform_', 'args': {'a': -0.05, 'b': 0.05}},
        linear_bias={'type': 'constant_', 'args': {'val': 0.}},
        rnn={'type': 'xavier_uniform_', 'args': {'gain': 1.}},
        rnn_bias={'type': 'constant_', 'args': {'val': 0.}},
        cnn_bias={'type': 'constant_', 'args': {'val': 0.}},
        emb={'type': 'normal_', 'args': {'mean': 0, 'std': 1}},
        default={'type': 'uniform_', 'args': {'a': -0.05, 'b': 0.05}}):
    """Build a per-module initializer to be passed to ``Module.apply``.

    Each argument selects an in-place ``torch.nn.init`` function by name
    (note the trailing underscore in current PyTorch) together with its
    keyword arguments. The default dicts are only read, never mutated, so
    using them as default argument values is safe here. ``default`` is
    accepted but never applied in this version.
    """
    rnns = (torch.nn.LSTM, torch.nn.GRU,
            torch.nn.LSTMCell, torch.nn.GRUCell,
            StackedGRU, StackedLSTM, NormalizedGRU,
            NormalizedGRUCell, StackedNormalizedGRU)

    convs = (torch.nn.Conv1d, torch.nn.Conv2d)

    def initializer(m):
        if isinstance(m, rnns):  # RNNs
            for p_name, p in m.named_parameters():
                if hasattr(p, 'custom'):
                    # parameters flagged with a `custom` attribute keep
                    # whatever initialization they already carry
                    continue
                if is_bias(p_name):
                    getattr(init, rnn_bias['type'])(p, **rnn_bias['args'])
                else:
                    getattr(init, rnn['type'])(p, **rnn['args'])
        elif isinstance(m, torch.nn.Linear):  # linear layers
            for p_name, p in m.named_parameters():
                if hasattr(p, 'custom'):
                    continue
                if is_bias(p_name):
                    getattr(init, linear_bias['type'])(p, **linear_bias['args'])
                else:
                    getattr(init, linear['type'])(p, **linear['args'])
        elif isinstance(m, torch.nn.Embedding):  # embeddings
            for p in m.parameters():
                if hasattr(p, 'custom'):
                    continue
                getattr(init, emb['type'])(p, **emb['args'])
        elif isinstance(m, convs):  # convolutions
            for p_name, p in m.named_parameters():
                if hasattr(p, 'custom'):
                    continue
                if is_bias(p_name):
                    getattr(init, cnn_bias['type'])(p, **cnn_bias['args'])
                else:
                    # Karpathy: http://cs231n.github.io/neural-networks-2/#init
                    # -> scale the weights by 1/sqrt(fan_in), e.g.
                    # fan_in, _ = init._calculate_fan_in_and_fan_out(p)
                    # init.normal_(p, mean=0, std=1. / math.sqrt(fan_in))
                    init.xavier_normal_(p)

    return initializer
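
For reference, a minimal usage sketch (the toy model, sizes and override below are hypothetical, not from the original): `Module.apply` visits every submodule recursively, and a parameter can be excluded from re-initialization by setting a `custom` attribute on it beforehand.

# hypothetical toy model, just to exercise the initializer
model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 64),
    torch.nn.Linear(64, 32))

# flag the embedding weights so the initializer leaves them untouched
model[0].weight.custom = True

# override only the linear-weight scheme; the other defaults stay in place
initializer = make_initializer(
    linear={'type': 'kaiming_uniform_', 'args': {'a': 0.}})
model.apply(initializer)  # apply() walks every submodule recursively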