def _init_weight(self): init.xavier_normal(self.w_qs) init.xavier_normal(self.w_ks) init.xavier_normal(self.w_vs) init.xavier_normal(self.w_o.weight)