def initWeight(self, init_forget_bias=1):
    """Re-initialise every parameter returned by ``self.named_parameters()`` in place.

    * A parameter whose name contains ``'weight'`` and that has at least two
      dimensions gets Xavier-uniform initialisation. A ``'weight'`` parameter
      with fewer than two dimensions (e.g. a normalisation-layer scale) is
      left untouched. Xavier's fan-in/fan-out is undefined for such tensors,
      and ``xavier_uniform_`` would raise ``ValueError`` on them.
    * A parameter whose name contains ``'lstm.bias_ih_l'`` or
      ``'lstm.bias_hh_l'`` only has its forget-gate slice set to
      ``init_forget_bias``. The input, cell and output gate slices keep their
      current values.
    * Every other parameter is set to 0.

    Args:
        init_forget_bias: value written into the forget-gate slice of each
            LSTM bias vector. PyTorch adds ``b_if`` and ``b_hf`` in the gate
            pre-activation, so the effective forget-gate bias at
            initialisation is twice this value.

    See https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
    for the LSTM parameter layout.
    """
    for name, params in self.named_parameters():
        if 'weight' in name:
            # Xavier needs fan_in/fan_out, which only exist for >= 2-D tensors.
            if params.dim() >= 2:
                init.xavier_uniform_(params)
        elif 'lstm.bias_ih_l' in name or 'lstm.bias_hh_l' in name:
            # nn.LSTM packs the four gate biases as (i, f, g, o) along dim 0,
            # so chunk 1 is the forget gate. chunk() returns views, so filling
            # the slice updates the parameter itself.
            init.constant_(params.chunk(4, 0)[1], init_forget_bias)
        else:
            # Catch-all: every remaining parameter (e.g. Linear biases) -> 0.
            init.constant_(params, 0)
# Scraped web-page residue (not code): "评论列表" = "Comment list",
# "文章目录" = "Table of contents".