def initWeight(self, init_forget_bias=1):
    """Initialise every parameter of this module in place.

    GRU parameter layout reference:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py

    Policy, applied per entry of ``self.named_parameters()``:

    * Names containing ``'weight'``: Xavier-uniform initialisation when the
      tensor has at least two dimensions. Lower-rank weights (for example
      normalisation-layer scales) are left unchanged, because Xavier init
      needs a fan-in and a fan-out and raises ``ValueError`` on 1-D tensors.
    * Names containing ``'gru.bias_ih_l'`` or ``'gru.bias_hh_l'``: the bias is
      split into its three gate chunks ``[reset | update | new]``. The
      update-gate chunk (``b_iz`` / ``b_hz``) is set to ``init_forget_bias``
      and the other two chunks are zeroed.
    * Everything else: zeroed.

    Args:
        init_forget_bias: Constant written into the GRU update-gate biases.
    """
    for name, params in self.named_parameters():
        if 'weight' in name:
            # Xavier is undefined for tensors with < 2 dims; keep their init.
            if params.dim() >= 2:
                init.xavier_uniform_(params)
        elif 'gru.bias_ih_l' in name or 'gru.bias_hh_l' in name:
            # Zero the whole bias first. Previously the reset/new chunks kept
            # PyTorch's default random init, unlike every other bias here.
            init.zeros_(params)
            # GRU biases are stacked as (b_r | b_z | b_n) along dim 0.
            # PyTorch's GRU computes h' = (1 - z) * n + z * h, so a positive
            # update-gate bias biases z toward 1 and favours keeping the
            # previous hidden state (the analogue of an LSTM forget bias).
            _, b_z, _ = params.chunk(3, 0)
            init.constant_(b_z, init_forget_bias)
        else:
            # All remaining parameters (other biases) start at zero.
            init.zeros_(params)
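# Web-page residue from the scraped source, not code:
# "评论列表" (comment list), "文章目录" (article table of contents).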