作者:rmason
项目:hei
// GrantStaff grants staff capability to the in-memory account identified by
// accountID.
//
// While holding the backend lock, it decrypts a clone of the account's
// SystemKey using the KMS obtained from kmsCred, then creates a shared-secret
// capability over kmsCred with that key and a freshly generated nonce. The
// result is stored as the account's staffCapability, overwriting any previous
// value.
//
// Returns proto.ErrAccountNotFound if the account is unknown; otherwise any
// error from key decryption, nonce generation, or capability creation is
// returned as-is.
func (m *accountManager) GrantStaff(
	ctx scope.Context, accountID snowflake.Snowflake, kmsCred security.KMSCredential) error {
	m.b.Lock()
	defer m.b.Unlock()

	acct, found := m.b.accounts[accountID]
	if !found {
		return proto.ErrAccountNotFound
	}
	target := acct.(*memAccount)

	// Decrypt a clone rather than the stored key itself, presumably so the
	// account's SystemKey is not modified in place by DecryptKey.
	kms := kmsCred.KMS()
	systemKey := target.sec.SystemKey.Clone()
	if err := kms.DecryptKey(&systemKey); err != nil {
		return err
	}

	nonce, err := kms.GenerateNonce(systemKey.KeyType.BlockSize())
	if err != nil {
		return err
	}

	capability, err := security.GrantSharedSecretCapability(&systemKey, nonce, kmsCred.KMSType(), kmsCred)
	if err != nil {
		return err
	}
	target.staffCapability = capability
	return nil
}
作者:NotAMoos
项目:hei
// GrantStaff grants staff capability to the account identified by accountID.
//
// It reads the account's encrypted system key and nonce, decrypts the key via
// the KMS obtained from kmsCred, and creates a shared-secret capability that
// embeds kmsCred encrypted under that key. The capability row is then inserted
// and linked to the account (account.staff_capability_id) in one transaction.
//
// Returns proto.ErrAccountNotFound if the account does not exist, either at
// lookup time or when the account row is updated. Any other database, KMS, or
// capability error is returned unchanged.
func (b *AccountManagerBinding) GrantStaff(
	ctx scope.Context, accountID snowflake.Snowflake, kmsCred security.KMSCredential) error {
	// Look up the target account's (system) encrypted client key. This is
	// not part of the transaction, because we want to interact with KMS
	// before we proceed. That should be fine, since this is an infrequently
	// used action.
	var row struct {
		EncryptedClientKey []byte `db:"encrypted_system_key"`
		Nonce              []byte `db:"nonce"`
	}
	err := b.DbMap.SelectOne(
		&row, "SELECT encrypted_system_key, nonce FROM account WHERE id = $1", accountID.String())
	if err != nil {
		if err == sql.ErrNoRows {
			return proto.ErrAccountNotFound
		}
		return err
	}

	// Use kmsCred to obtain kms and decrypt the client's key. The account's
	// nonce (base64 URL-encoded) is supplied as the key's context under the
	// "nonce" context key.
	kms := kmsCred.KMS()
	clientKey := &security.ManagedKey{
		KeyType:      proto.ClientKeyType,
		Ciphertext:   row.EncryptedClientKey,
		ContextKey:   "nonce",
		ContextValue: base64.URLEncoding.EncodeToString(row.Nonce),
	}
	if err := kms.DecryptKey(clientKey); err != nil {
		return err
	}

	// Grant staff capability. This involves marshalling kmsCred to JSON and
	// encrypting it with the client key.
	nonce, err := kms.GenerateNonce(clientKey.KeyType.BlockSize())
	if err != nil {
		return err
	}
	capability, err := security.GrantSharedSecretCapability(clientKey, nonce, kmsCred.KMSType(), kmsCred)
	if err != nil {
		return err
	}

	// Store capability and update account table.
	t, err := b.DbMap.Begin()
	if err != nil {
		return err
	}
	// Roll back on every exit path, including panics, until the transaction
	// is handed to Commit. Commit ends the transaction whether or not it
	// succeeds (database/sql semantics), so no rollback is attempted after it.
	handedToCommit := false
	defer func() {
		if handedToCommit {
			return
		}
		if err := t.Rollback(); err != nil {
			backend.Logger(ctx).Printf("rollback error: %s", err)
		}
	}()

	dbCap := &Capability{
		ID:                   capability.CapabilityID(),
		NonceBytes:           capability.Nonce(),
		EncryptedPrivateData: capability.EncryptedPayload(),
		PublicData:           capability.PublicPayload(),
	}
	if err := t.Insert(dbCap); err != nil {
		return err
	}

	result, err := t.Exec(
		"UPDATE account SET staff_capability_id = $2 WHERE id = $1",
		accountID.String(), capability.CapabilityID())
	if err != nil {
		return err
	}
	n, err := result.RowsAffected()
	if err != nil {
		return err
	}
	// The account may have been deleted after the lookup above, which ran
	// outside this transaction.
	if n != 1 {
		return proto.ErrAccountNotFound
	}

	handedToCommit = true
	return t.Commit()
}