def __init__(self,
             num_heads: int,
             input_dim: int,
             attention_dim: int,
             values_dim: int,
             output_projection_dim: int = None,
             attention_dropout_prob: float = 0.1) -> None:
    """Set up the per-head projection parameters for multi-head self-attention.

    Parameters
    ----------
    num_heads : int
        Number of attention heads; each head gets its own projection matrices.
    input_dim : int
        Dimensionality of the incoming token representations.
    attention_dim : int
        Per-head dimensionality of the query and key projections.
    values_dim : int
        Per-head dimensionality of the value projection.
    output_projection_dim : int, optional (default = None)
        Size of the final output projection; falls back to ``input_dim``
        when not given.  NOTE(review): annotation would be more precise as
        ``Optional[int]`` — left as-is to avoid assuming a ``typing`` import.
    attention_dropout_prob : float, optional (default = 0.1)
        Dropout probability applied by ``self._attention_dropout``
        (presumably to the attention weights in ``forward`` — confirm there).
    """
    # Python 3 zero-argument form; equivalent to the explicit two-argument call.
    super().__init__()
    self._num_heads = num_heads
    self._input_dim = input_dim
    # Preserve the input size when no explicit output projection size is given.
    self._output_dim = output_projection_dim or input_dim
    self._attention_dim = attention_dim
    self._values_dim = values_dim
    # One (input_dim x attention_dim) matrix per head for queries and keys,
    # one (input_dim x values_dim) matrix per head for values.  The tensors
    # are allocated uninitialized here; reset_parameters() fills them in.
    self._query_projections = Parameter(torch.FloatTensor(num_heads, input_dim, attention_dim))
    self._key_projections = Parameter(torch.FloatTensor(num_heads, input_dim, attention_dim))
    self._value_projections = Parameter(torch.FloatTensor(num_heads, input_dim, values_dim))
    # NOTE(review): standard scaled dot-product attention divides scores by
    # sqrt(per-head key dim) — ``attention_dim`` here — not sqrt(input_dim).
    # Left unchanged pending confirmation against how forward() uses it.
    self._scale = input_dim ** 0.5
    # Head outputs are concatenated before this projection,
    # hence the num_heads * values_dim input width.
    self._output_projection = Linear(num_heads * values_dim,
                                     self._output_dim)
    self._attention_dropout = Dropout(attention_dropout_prob)
    # Sibling method (defined elsewhere in the class) that initializes
    # the projection parameters above.
    self.reset_parameters()