Skip to content

Commit 9ff023e

Browse files
authored
fix: Fix filtering by partition key fails for importing data (milvus-io#33274)
Before executing the import, partition IDs should be reordered according to partition names. Otherwise, the data might be hashed to the wrong partition during import. This PR corrects this error. issue: milvus-io#33237 --------- Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
1 parent be77ceb commit 9ff023e

6 files changed

Lines changed: 234 additions & 36 deletions

File tree

internal/proxy/impl.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6100,7 +6100,11 @@ func (node *Proxy) ImportV2(ctx context.Context, req *internalpb.ImportRequest)
61006100
resp.Status = merr.Status(err)
61016101
return resp, nil
61026102
}
6103-
partitionIDs = lo.Values(partitions)
6103+
_, partitionIDs, err = typeutil.RearrangePartitionsForPartitionKey(partitions)
6104+
if err != nil {
6105+
resp.Status = merr.Status(err)
6106+
return resp, nil
6107+
}
61046108
} else {
61056109
if req.GetPartitionName() == "" {
61066110
req.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()

internal/proxy/msg_pack.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ func repackInsertDataWithPartitionKey(ctx context.Context,
231231
}
232232

233233
channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg)
234-
partitionNames, err := getDefaultPartitionNames(ctx, insertMsg.GetDbName(), insertMsg.CollectionName)
234+
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, insertMsg.GetDbName(), insertMsg.CollectionName)
235235
if err != nil {
236236
log.Warn("get default partition names failed in partition key mode",
237237
zap.String("collectionName", insertMsg.CollectionName),

internal/proxy/task_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3135,7 +3135,7 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
31353135
// check default partitions
31363136
err = InitMetaCache(ctx, rc, nil, nil)
31373137
assert.NoError(t, err)
3138-
partitionNames, err := getDefaultPartitionNames(ctx, "", task.CollectionName)
3138+
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", task.CollectionName)
31393139
assert.NoError(t, err)
31403140
assert.Equal(t, task.GetNumPartitions(), int64(len(partitionNames)))
31413141

internal/proxy/util.go

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ func hasParitionKeyModeField(schema *schemapb.CollectionSchema) bool {
13951395
return false
13961396
}
13971397

1398-
// getDefaultPartitionNames only used in partition key mode
1398+
// getDefaultPartitionsInPartitionKeyMode only used in partition key mode
13991399
func getDefaultPartitionsInPartitionKeyMode(ctx context.Context, dbName string, collectionName string) ([]string, error) {
14001400
partitions, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
14011401
if err != nil {
@@ -1411,32 +1411,6 @@ func getDefaultPartitionsInPartitionKeyMode(ctx context.Context, dbName string,
14111411
return partitionNames, nil
14121412
}
14131413

1414-
// getDefaultPartitionNames only used in partition key mode
1415-
func getDefaultPartitionNames(ctx context.Context, dbName string, collectionName string) ([]string, error) {
1416-
partitions, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
1417-
if err != nil {
1418-
return nil, err
1419-
}
1420-
1421-
// Make sure the order of the partition names got every time is the same
1422-
partitionNames := make([]string, len(partitions))
1423-
for partitionName := range partitions {
1424-
splits := strings.Split(partitionName, "_")
1425-
if len(splits) < 2 {
1426-
err = fmt.Errorf("bad default partion name in partition ket mode: %s", partitionName)
1427-
return nil, err
1428-
}
1429-
index, err := strconv.ParseInt(splits[len(splits)-1], 10, 64)
1430-
if err != nil {
1431-
return nil, err
1432-
}
1433-
1434-
partitionNames[index] = partitionName
1435-
}
1436-
1437-
return partitionNames, nil
1438-
}
1439-
14401414
func assignChannelsByPK(pks *schemapb.IDs, channelNames []string, insertMsg *msgstream.InsertMsg) map[string][]int {
14411415
insertMsg.HashValues = typeutil.HashPK2Channels(pks, channelNames)
14421416

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
// Licensed to the LF AI & Data foundation under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
package importv2
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"math/rand"
23+
"os"
24+
"strings"
25+
"time"
26+
27+
"github.com/golang/protobuf/proto"
28+
"github.com/samber/lo"
29+
"go.uber.org/zap"
30+
31+
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
32+
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
33+
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
34+
"github.com/milvus-io/milvus/internal/proto/internalpb"
35+
"github.com/milvus-io/milvus/pkg/common"
36+
"github.com/milvus-io/milvus/pkg/log"
37+
"github.com/milvus-io/milvus/pkg/util/funcutil"
38+
"github.com/milvus-io/milvus/pkg/util/merr"
39+
"github.com/milvus-io/milvus/pkg/util/metric"
40+
"github.com/milvus-io/milvus/tests/integration"
41+
)
42+
43+
func (s *BulkInsertSuite) TestImportWithPartitionKey() {
44+
const (
45+
rowCount = 10000
46+
)
47+
48+
c := s.Cluster
49+
ctx, cancel := context.WithTimeout(c.GetContext(), 60*time.Second)
50+
defer cancel()
51+
52+
collectionName := "TestBulkInsert_WithPartitionKey_" + funcutil.GenRandomStr()
53+
54+
schema := integration.ConstructSchema(collectionName, dim, true, &schemapb.FieldSchema{
55+
FieldID: 100,
56+
Name: integration.Int64Field,
57+
IsPrimaryKey: true,
58+
DataType: schemapb.DataType_Int64,
59+
AutoID: true,
60+
}, &schemapb.FieldSchema{
61+
FieldID: 101,
62+
Name: integration.FloatVecField,
63+
DataType: schemapb.DataType_FloatVector,
64+
TypeParams: []*commonpb.KeyValuePair{
65+
{
66+
Key: common.DimKey,
67+
Value: fmt.Sprintf("%d", dim),
68+
},
69+
},
70+
}, &schemapb.FieldSchema{
71+
FieldID: 102,
72+
Name: integration.VarCharField,
73+
DataType: schemapb.DataType_VarChar,
74+
TypeParams: []*commonpb.KeyValuePair{
75+
{
76+
Key: common.MaxLengthKey,
77+
Value: fmt.Sprintf("%d", 256),
78+
},
79+
},
80+
IsPartitionKey: true,
81+
})
82+
marshaledSchema, err := proto.Marshal(schema)
83+
s.NoError(err)
84+
85+
createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
86+
DbName: "",
87+
CollectionName: collectionName,
88+
Schema: marshaledSchema,
89+
ShardsNum: common.DefaultShardsNum,
90+
})
91+
s.NoError(err)
92+
s.Equal(int32(0), createCollectionStatus.GetCode())
93+
94+
// create index
95+
createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
96+
CollectionName: collectionName,
97+
FieldName: integration.FloatVecField,
98+
IndexName: "_default",
99+
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
100+
})
101+
s.NoError(err)
102+
s.Equal(int32(0), createIndexStatus.GetCode())
103+
104+
s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
105+
106+
// import
107+
var files []*internalpb.ImportFile
108+
err = os.MkdirAll(c.ChunkManager.RootPath(), os.ModePerm)
109+
s.NoError(err)
110+
111+
filePath := fmt.Sprintf("/tmp/test_%d.parquet", rand.Int())
112+
insertData, err := GenerateParquetFileAndReturnInsertData(filePath, schema, rowCount)
113+
s.NoError(err)
114+
defer os.Remove(filePath)
115+
files = []*internalpb.ImportFile{
116+
{
117+
Paths: []string{
118+
filePath,
119+
},
120+
},
121+
}
122+
123+
importResp, err := c.Proxy.ImportV2(ctx, &internalpb.ImportRequest{
124+
CollectionName: collectionName,
125+
Files: files,
126+
})
127+
s.NoError(err)
128+
s.Equal(int32(0), importResp.GetStatus().GetCode())
129+
log.Info("Import result", zap.Any("importResp", importResp))
130+
131+
jobID := importResp.GetJobID()
132+
err = WaitForImportDone(ctx, c, jobID)
133+
s.NoError(err)
134+
135+
// load
136+
loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
137+
CollectionName: collectionName,
138+
})
139+
s.NoError(err)
140+
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
141+
s.WaitForLoad(ctx, collectionName)
142+
143+
segments, err := c.MetaWatcher.ShowSegments()
144+
s.NoError(err)
145+
s.NotEmpty(segments)
146+
log.Info("Show segments", zap.Any("segments", segments))
147+
148+
// load refresh
149+
loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
150+
CollectionName: collectionName,
151+
Refresh: true,
152+
})
153+
s.NoError(err)
154+
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
155+
s.WaitForLoadRefresh(ctx, "", collectionName)
156+
157+
// query partition key, TermExpr
158+
queryNum := 10
159+
partitionKeyData := insertData.Data[int64(102)].GetRows().([]string)
160+
queryData := partitionKeyData[:queryNum]
161+
strs := lo.Map(queryData, func(str string, _ int) string {
162+
return fmt.Sprintf("\"%s\"", str)
163+
})
164+
str := strings.Join(strs, `,`)
165+
expr := fmt.Sprintf("%s in [%v]", integration.VarCharField, str)
166+
queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
167+
CollectionName: collectionName,
168+
Expr: expr,
169+
OutputFields: []string{integration.VarCharField},
170+
})
171+
err = merr.CheckRPCCall(queryResult, err)
172+
s.NoError(err)
173+
for _, data := range queryResult.GetFieldsData() {
174+
if data.GetType() == schemapb.DataType_VarChar {
175+
resData := data.GetScalars().GetStringData().GetData()
176+
s.Equal(queryNum, len(resData))
177+
s.ElementsMatch(resData, queryData)
178+
}
179+
}
180+
181+
// query partition key, CmpOp 1
182+
expr = fmt.Sprintf("%s >= 0", integration.Int64Field)
183+
queryResult, err = c.Proxy.Query(ctx, &milvuspb.QueryRequest{
184+
CollectionName: collectionName,
185+
Expr: expr,
186+
OutputFields: []string{integration.VarCharField},
187+
})
188+
err = merr.CheckRPCCall(queryResult, err)
189+
s.NoError(err)
190+
for _, data := range queryResult.GetFieldsData() {
191+
if data.GetType() == schemapb.DataType_VarChar {
192+
resData := data.GetScalars().GetStringData().GetData()
193+
s.Equal(rowCount, len(resData))
194+
s.ElementsMatch(resData, partitionKeyData)
195+
}
196+
}
197+
198+
// query partition key, CmpOp 2
199+
target := partitionKeyData[rand.Intn(rowCount)]
200+
expr = fmt.Sprintf("%s == \"%s\"", integration.VarCharField, target)
201+
queryResult, err = c.Proxy.Query(ctx, &milvuspb.QueryRequest{
202+
CollectionName: collectionName,
203+
Expr: expr,
204+
OutputFields: []string{integration.VarCharField},
205+
})
206+
err = merr.CheckRPCCall(queryResult, err)
207+
s.NoError(err)
208+
for _, data := range queryResult.GetFieldsData() {
209+
if data.GetType() == schemapb.DataType_VarChar {
210+
resData := data.GetScalars().GetStringData().GetData()
211+
s.Equal(1, len(resData))
212+
s.Equal(resData[0], target)
213+
}
214+
}
215+
}

tests/integration/import/util_test.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,33 +46,38 @@ import (
4646
const dim = 128
4747

4848
func GenerateParquetFile(filePath string, schema *schemapb.CollectionSchema, numRows int) error {
49+
_, err := GenerateParquetFileAndReturnInsertData(filePath, schema, numRows)
50+
return err
51+
}
52+
53+
func GenerateParquetFileAndReturnInsertData(filePath string, schema *schemapb.CollectionSchema, numRows int) (*storage.InsertData, error) {
4954
w, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666)
5055
if err != nil {
51-
return err
56+
return nil, err
5257
}
5358

5459
pqSchema, err := pq.ConvertToArrowSchema(schema)
5560
if err != nil {
56-
return err
61+
return nil, err
5762
}
5863
fw, err := pqarrow.NewFileWriter(pqSchema, w, parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(int64(numRows))), pqarrow.DefaultWriterProps())
5964
if err != nil {
60-
return err
65+
return nil, err
6166
}
6267
defer fw.Close()
6368

6469
insertData, err := testutil.CreateInsertData(schema, numRows)
6570
if err != nil {
66-
return err
71+
return nil, err
6772
}
6873

6974
columns, err := testutil.BuildArrayData(schema, insertData)
7075
if err != nil {
71-
return err
76+
return nil, err
7277
}
7378

7479
recordBatch := array.NewRecord(pqSchema, columns, int64(numRows))
75-
return fw.Write(recordBatch)
80+
return insertData, fw.Write(recordBatch)
7681
}
7782

7883
func GenerateNumpyFiles(cm storage.ChunkManager, schema *schemapb.CollectionSchema, rowCount int) (*internalpb.ImportFile, error) {

0 commit comments

Comments
 (0)