Skip to content

Commit da96659

Browse files
authored
[skip e2e] Add more testcases with different parameter combinations in test_reduce (milvus-io#18967)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com> Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
1 parent dcddc9d commit da96659

1 file changed

Lines changed: 34 additions & 18 deletions

File tree

internal/core/unittest/test_reduce.cpp

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ using SubSearchResultUniq = std::unique_ptr<SubSearchResult>;
2424

2525
std::default_random_engine e(42);
2626

27-
std::unique_ptr<SubSearchResult>
27+
SubSearchResultUniq
2828
GenSubSearchResult(const int64_t nq,
2929
const int64_t topk,
3030
const knowhere::MetricType &metric_type,
@@ -34,8 +34,8 @@ GenSubSearchResult(const int64_t nq,
3434
SubSearchResultUniq sub_result = std::make_unique<SubSearchResult>(nq, topk, metric_type, round_decimal);
3535
std::vector<int64_t> ids;
3636
std::vector<float> distances;
37-
for (int n = 0; n < nq; ++n) {
38-
for (int k = 0; k < topk; ++k) {
37+
for (auto n = 0; n < nq; ++n) {
38+
for (auto k = 0; k < topk; ++k) {
3939
auto gen_x = e() % limit;
4040
ids.push_back(gen_x);
4141
distances.push_back(gen_x);
@@ -57,7 +57,7 @@ template<class queue_type>
5757
void
5858
CheckSubSearchResult(const int64_t nq,
5959
const int64_t topk,
60-
SubSearchResult& search_result,
60+
SubSearchResult& result,
6161
std::vector<queue_type>& result_ref) {
6262
ASSERT_EQ(result_ref.size(), nq);
6363
for (int n = 0; n < nq; ++n) {
@@ -66,8 +66,8 @@ CheckSubSearchResult(const int64_t nq,
6666
auto ref_x = result_ref[n].top();
6767
result_ref[n].pop();
6868
auto index = n * topk + topk - 1 - k;
69-
auto id = search_result.get_seg_offsets()[index];
70-
auto distance = search_result.get_distances()[index];
69+
auto id = result.get_seg_offsets()[index];
70+
auto distance = result.get_distances()[index];
7171
ASSERT_EQ(id, ref_x);
7272
ASSERT_EQ(distance, ref_x);
7373
}
@@ -76,19 +76,19 @@ CheckSubSearchResult(const int64_t nq,
7676

7777
template<class queue_type>
7878
void
79-
TestSubSearchResultMerge(const knowhere::MetricType& metric_type) {
80-
int64_t num_queries = 16;
81-
int64_t topk = 10;
82-
int64_t iteration = 10;
83-
int64_t round_decimal = 3;
79+
TestSubSearchResultMerge(const knowhere::MetricType& metric_type,
80+
const int64_t iteration,
81+
const int64_t nq,
82+
const int64_t topk) {
83+
const int64_t round_decimal = 3;
8484

85-
std::vector<queue_type> result_ref(num_queries);
85+
std::vector<queue_type> result_ref(nq);
8686

87-
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
87+
SubSearchResult final_result(nq, topk, metric_type, round_decimal);
8888
for (int i = 0; i < iteration; ++i) {
89-
SubSearchResultUniq sub_result = GenSubSearchResult(num_queries, topk, metric_type, round_decimal);
89+
SubSearchResultUniq sub_result = GenSubSearchResult(nq, topk, metric_type, round_decimal);
9090
auto ids = sub_result->get_ids();
91-
for (int n = 0; n < num_queries; ++n) {
91+
for (int n = 0; n < nq; ++n) {
9292
for (int k = 0; k < topk; ++k) {
9393
int64_t x = ids[n * topk + k];
9494
result_ref[n].push(x);
@@ -99,12 +99,28 @@ TestSubSearchResultMerge(const knowhere::MetricType& metric_type) {
9999
}
100100
final_result.merge(*sub_result);
101101
}
102-
CheckSubSearchResult<queue_type>(num_queries, topk, final_result, result_ref);
102+
CheckSubSearchResult<queue_type>(nq, topk, final_result, result_ref);
103103
}
104104

105105
TEST(Reduce, SubSearchResult) {
106106
using queue_type_l2 = std::priority_queue<int64_t, std::vector<int64_t>, std::less<int64_t>>;
107107
using queue_type_ip = std::priority_queue<int64_t, std::vector<int64_t>, std::greater<int64_t>>;
108-
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2);
109-
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP);
108+
109+
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 1, 1, 1);
110+
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 1, 1, 10);
111+
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 1, 16, 1);
112+
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 1, 16, 10);
113+
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 4, 1, 1);
114+
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 4, 1, 10);
115+
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 4, 16, 1);
116+
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2, 4, 16, 10);
117+
118+
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 1, 1, 1);
119+
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 1, 1, 10);
120+
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 1, 16, 1);
121+
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 1, 16, 10);
122+
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 4, 1, 1);
123+
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 4, 1, 10);
124+
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 4, 16, 1);
125+
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP, 4, 16, 10);
110126
}

0 commit comments

Comments
 (0)