def __init__(self, n_head, d_model, dropout):
    """Set up multi-head attention parameters and sub-modules.

    Args:
        n_head: number of attention heads.
        d_model: model/embedding dimension; assumed divisible by ``n_head``
            (the per-head size is computed by floor division — TODO confirm
            callers guarantee divisibility).
        dropout: dropout rate, forwarded to ``ScaledDotProductAttention``
            and also stored on the instance.
    """
    super().__init__()
    self.n_head = n_head
    # Per-head projection size; this implementation uses d_v == d_k.
    self.d_v = self.d_k = d_k = d_model // n_head
    # One projection weight per role (query / key / value), each of shape
    # (n_head, d_model, d_k). torch.empty replaces the legacy
    # torch.FloatTensor constructor — both yield an *uninitialized*
    # float32 tensor; values are presumably filled by _init_weight()
    # below (TODO confirm _init_weight covers w_qs/w_ks/w_vs).
    for name in ("w_qs", "w_ks", "w_vs"):
        setattr(self, name, nn.Parameter(torch.empty(n_head, d_model, d_k)))
    self.attention = ScaledDotProductAttention(d_k, dropout)
    self.lm = LayerNorm(d_model)
    # Output projection merging the heads back to d_model.
    self.w_o = nn.Linear(d_model, d_model, bias=False)
    # NOTE(review): stored as the raw rate (float), not an nn.Dropout
    # module — confirm the forward pass expects it this way.
    self.dropout = dropout
    self._init_weight()