Skip to content

Commit 7a07013

Browse files
authored
feat: implement KMS key revocation deny read/write (milvus-io#45936)
Adds KMS key state monitoring and deny reading/writing for abnormal kms Key Changes: - Add KeyManager in RootCoord for periodic KMS state polling - Integrate KeyManager with QuotaCenter for access denial - Implement revocation checks in Proxy SimpleLimiter Access Denial: - Revoked keys: Release collections + deny DML/DQL (DDL still allowed) - Manual LoadCollection required after key recovery See also: milvus-io#45117, milvus-io#44981, milvus-io#45242 --------- Signed-off-by: yangxuan <xuan.yang@zilliz.com>
1 parent 7b1efc7 commit 7a07013

21 files changed

Lines changed: 676 additions & 291 deletions

internal/mocks/flushcommon/mock_util/mock_MsgHandler.go

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/proxy/simple_rate_limiter.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,23 +162,27 @@ func isNotCollectionLevelLimitRequest(rt internalpb.RateType) bool {
162162
func (m *SimpleLimiter) GetQuotaStates() ([]milvuspb.QuotaState, []string) {
163163
m.quotaStatesMu.RLock()
164164
defer m.quotaStatesMu.RUnlock()
165-
serviceStates := make(map[milvuspb.QuotaState]typeutil.Set[commonpb.ErrorCode])
165+
type stateReasonKey struct {
166+
ErrorCode commonpb.ErrorCode
167+
Reason string
168+
}
169+
serviceStates := make(map[milvuspb.QuotaState]typeutil.Set[stateReasonKey])
166170

167171
rlinternal.TraverseRateLimiterTree(m.rateLimiter.GetRootLimiters(), nil,
168-
func(node *rlinternal.RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode) bool {
172+
func(node *rlinternal.RateLimiterNode, state milvuspb.QuotaState, errCode commonpb.ErrorCode, reason string) bool {
169173
if serviceStates[state] == nil {
170-
serviceStates[state] = typeutil.NewSet[commonpb.ErrorCode]()
174+
serviceStates[state] = typeutil.NewSet[stateReasonKey]()
171175
}
172-
serviceStates[state].Insert(errCode)
176+
serviceStates[state].Insert(stateReasonKey{ErrorCode: errCode, Reason: reason})
173177
return true
174178
})
175179

176180
states := make([]milvuspb.QuotaState, 0)
177181
reasons := make([]string, 0)
178-
for state, errCodes := range serviceStates {
179-
for errCode := range errCodes {
182+
for state, stateReasonKeys := range serviceStates {
183+
for key := range stateReasonKeys {
180184
states = append(states, state)
181-
reasons = append(reasons, ratelimitutil.GetQuotaErrorString(errCode))
185+
reasons = append(reasons, ratelimitutil.GetQuotaErrorStringWithReason(key.ErrorCode, key.Reason))
182186
}
183187
}
184188

@@ -285,11 +289,19 @@ func (m *SimpleLimiter) updateLimiterNode(req *proxypb.Limiter, node *rlinternal
285289
limit.SetLimit(ratelimitutil.Limit(rate.GetR()))
286290
setRateGaugeByRateType(rate.GetRt(), paramtable.GetNodeID(), sourceID, rate.GetR())
287291
}
288-
quotaStates := typeutil.NewConcurrentMap[milvuspb.QuotaState, commonpb.ErrorCode]()
292+
quotaStates := typeutil.NewConcurrentMap[milvuspb.QuotaState, *rlinternal.QuotaStateInfo]()
289293
states := req.GetStates()
290294
codes := req.GetCodes()
295+
reasons := req.GetReasons()
291296
for i, state := range states {
292-
quotaStates.Insert(state, codes[i])
297+
reason := ""
298+
if i < len(reasons) {
299+
reason = reasons[i]
300+
}
301+
quotaStates.Insert(state, &rlinternal.QuotaStateInfo{
302+
ErrorCode: codes[i],
303+
Reason: reason,
304+
})
293305
}
294306
node.SetQuotaStates(quotaStates)
295307
return nil

internal/rootcoord/ddl_callbacks_alter_collection_properties_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
2626
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
2727
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
28-
"github.com/milvus-io/milvus/internal/util/hookutil"
2928
"github.com/milvus-io/milvus/pkg/v2/common"
3029
"github.com/milvus-io/milvus/pkg/v2/util"
3130
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@@ -60,7 +59,7 @@ func TestDDLCallbacksAlterCollectionProperties(t *testing.T) {
6059
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
6160
DbName: dbName,
6261
CollectionName: collectionName,
63-
Properties: []*commonpb.KeyValuePair{{Key: hookutil.EncryptionEnabledKey, Value: "1"}},
62+
Properties: []*commonpb.KeyValuePair{{Key: common.EncryptionEnabledKey, Value: "1"}},
6463
})
6564
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
6665

internal/rootcoord/ddl_callbacks_alter_database.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func (c *DDLCallback) alterDatabaseV1AckCallback(ctx context.Context, result mes
137137
header := result.Message.Header()
138138
body := result.Message.MustBody()
139139

140-
db := model.NewDatabase(header.DbId, header.DbName, etcdpb.DatabaseState_DatabaseCreated, result.Message.MustBody().Properties)
140+
db := model.NewDatabase(header.DbId, header.DbName, etcdpb.DatabaseState_DatabaseCreated, body.Properties)
141141
if err := c.meta.AlterDatabase(ctx, db, result.GetControlChannelResult().TimeTick); err != nil {
142142
return errors.Wrap(err, "failed to alter database")
143143
}

internal/rootcoord/ddl_callbacks_alter_database_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323

2424
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
2525
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
26-
"github.com/milvus-io/milvus/internal/util/hookutil"
2726
"github.com/milvus-io/milvus/pkg/v2/common"
2827
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
2928
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@@ -54,7 +53,7 @@ func TestDDLCallbacksAlterDatabase(t *testing.T) {
5453
// hook related properties are not allowed to be altered.
5554
resp, err = core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
5655
DbName: dbName,
57-
Properties: []*commonpb.KeyValuePair{{Key: hookutil.EncryptionEnabledKey, Value: "1"}},
56+
Properties: []*commonpb.KeyValuePair{{Key: common.EncryptionEnabledKey, Value: "1"}},
5857
})
5958
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
6059

internal/rootcoord/key_manager.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// Licensed to the LF AI & Data foundation under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
package rootcoord
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"strconv"
23+
24+
"github.com/samber/lo"
25+
"go.uber.org/zap"
26+
27+
"github.com/milvus-io/milvus/internal/metastore/model"
28+
"github.com/milvus-io/milvus/internal/util/hookutil"
29+
"github.com/milvus-io/milvus/pkg/v2/common"
30+
"github.com/milvus-io/milvus/pkg/v2/log"
31+
"github.com/milvus-io/milvus/pkg/v2/util"
32+
)
33+
34+
type KeyManager struct {
35+
ctx context.Context
36+
meta IMetaTable
37+
enabled bool
38+
}
39+
40+
func NewKeyManager(
41+
ctx context.Context,
42+
meta IMetaTable,
43+
) *KeyManager {
44+
if hookutil.GetCipherWithState() == nil {
45+
log.Info("KeyManager disabled (cipher plugin not loaded)")
46+
return nil
47+
}
48+
log.Info("KeyManager enabled")
49+
return &KeyManager{
50+
ctx: ctx,
51+
meta: meta,
52+
}
53+
}
54+
55+
func (km *KeyManager) GetRevokedDatabases() ([]int64, error) {
56+
currentStates, err := hookutil.GetEzStates()
57+
if err != nil {
58+
return nil, fmt.Errorf("failed to get cipher states: %w", err)
59+
}
60+
61+
abnormalDB := make(map[int64]struct{})
62+
for ezID, currentState := range currentStates {
63+
if currentState != hookutil.KeyStateEnabled {
64+
db, err := km.getDatabaseByEzID(ezID)
65+
if err != nil {
66+
log.Warn("KeyManager: failed to get database for ezID", zap.Int64("ezID", ezID), zap.Error(err))
67+
continue
68+
}
69+
70+
abnormalDB[db.ID] = struct{}{}
71+
}
72+
}
73+
74+
revokedDBIDs := lo.Keys(abnormalDB)
75+
return revokedDBIDs, nil
76+
}
77+
78+
func (km *KeyManager) getDatabaseByEzID(ezID int64) (*model.Database, error) {
79+
// use ezID as dbID to get database
80+
db, err := km.meta.GetDatabaseByID(km.ctx, ezID, 0)
81+
if err != nil {
82+
// fallback to default database(dbID=1)
83+
db, err = km.meta.GetDatabaseByID(km.ctx, util.DefaultDBID, 0)
84+
if err != nil {
85+
return nil, err
86+
}
87+
}
88+
89+
// verify the ezID matches the retrieved DB
90+
if db.GetProperty(common.EncryptionEzIDKey) != strconv.FormatInt(ezID, 10) {
91+
return nil, fmt.Errorf("db for ezID %d not found", ezID)
92+
}
93+
94+
return db, nil
95+
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Licensed to the LF AI & Data foundation under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
package rootcoord
18+
19+
import (
20+
"context"
21+
"strconv"
22+
"testing"
23+
24+
"github.com/cockroachdb/errors"
25+
"github.com/stretchr/testify/assert"
26+
27+
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
28+
"github.com/milvus-io/milvus/internal/metastore/model"
29+
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
30+
"github.com/milvus-io/milvus/internal/util/hookutil"
31+
"github.com/milvus-io/milvus/pkg/v2/common"
32+
"github.com/milvus-io/milvus/pkg/v2/util"
33+
)
34+
35+
func TestNewKeyManager(t *testing.T) {
36+
ctx := context.Background()
37+
meta := mockrootcoord.NewIMetaTable(t)
38+
hookutil.InitTestCipher()
39+
40+
km := NewKeyManager(ctx, meta)
41+
42+
assert.NotNil(t, km)
43+
assert.Equal(t, ctx, km.ctx)
44+
assert.Equal(t, meta, km.meta)
45+
}
46+
47+
func TestKeyManager_GetDatabaseByEzID(t *testing.T) {
48+
ctx := context.Background()
49+
hookutil.InitTestCipher()
50+
51+
t.Run("success get database", func(t *testing.T) {
52+
meta := mockrootcoord.NewIMetaTable(t)
53+
54+
expectedDB := &model.Database{
55+
ID: 123,
56+
Name: "test_db",
57+
Properties: []*commonpb.KeyValuePair{
58+
{
59+
Key: common.EncryptionEzIDKey,
60+
Value: "123", // the same as the dbID
61+
},
62+
{
63+
Key: common.EncryptionEnabledKey,
64+
Value: "true",
65+
},
66+
},
67+
}
68+
69+
meta.EXPECT().GetDatabaseByID(ctx, int64(123), uint64(0)).Return(expectedDB, nil).Once()
70+
71+
km := &KeyManager{
72+
ctx: ctx,
73+
meta: meta,
74+
}
75+
76+
db, err := km.getDatabaseByEzID(123)
77+
assert.NoError(t, err)
78+
assert.Equal(t, expectedDB, db)
79+
})
80+
81+
t.Run("fallback to default database", func(t *testing.T) {
82+
meta := mockrootcoord.NewIMetaTable(t)
83+
84+
ezID := int64(19530)
85+
defaultDB := &model.Database{
86+
ID: util.DefaultDBID,
87+
Name: util.DefaultDBName,
88+
Properties: []*commonpb.KeyValuePair{
89+
{
90+
Key: common.EncryptionEzIDKey,
91+
Value: strconv.FormatInt(ezID, 10),
92+
},
93+
{
94+
Key: common.EncryptionEnabledKey,
95+
Value: "true",
96+
},
97+
},
98+
}
99+
100+
meta.EXPECT().GetDatabaseByID(ctx, ezID, uint64(0)).Return(nil, errors.New("db not found")).Once()
101+
meta.EXPECT().GetDatabaseByID(ctx, util.DefaultDBID, uint64(0)).Return(defaultDB, nil).Once()
102+
103+
km := &KeyManager{
104+
ctx: ctx,
105+
meta: meta,
106+
}
107+
108+
db, err := km.getDatabaseByEzID(ezID)
109+
assert.NoError(t, err)
110+
assert.Equal(t, defaultDB, db)
111+
})
112+
}

0 commit comments

Comments
 (0)