@@ -24,7 +24,7 @@ using SubSearchResultUniq = std::unique_ptr<SubSearchResult>;
2424
2525std::default_random_engine e (42 );
2626
27- std::unique_ptr<SubSearchResult>
27+ SubSearchResultUniq
2828GenSubSearchResult (const int64_t nq,
2929 const int64_t topk,
3030 const knowhere::MetricType &metric_type,
@@ -34,8 +34,8 @@ GenSubSearchResult(const int64_t nq,
3434 SubSearchResultUniq sub_result = std::make_unique<SubSearchResult>(nq, topk, metric_type, round_decimal);
3535 std::vector<int64_t > ids;
3636 std::vector<float > distances;
37- for (int n = 0 ; n < nq; ++n) {
38- for (int k = 0 ; k < topk; ++k) {
37+ for (auto n = 0 ; n < nq; ++n) {
38+ for (auto k = 0 ; k < topk; ++k) {
3939 auto gen_x = e () % limit;
4040 ids.push_back (gen_x);
4141 distances.push_back (gen_x);
@@ -57,7 +57,7 @@ template<class queue_type>
5757void
5858CheckSubSearchResult (const int64_t nq,
5959 const int64_t topk,
60- SubSearchResult& search_result ,
60+ SubSearchResult& result ,
6161 std::vector<queue_type>& result_ref) {
6262 ASSERT_EQ (result_ref.size (), nq);
6363 for (int n = 0 ; n < nq; ++n) {
@@ -66,8 +66,8 @@ CheckSubSearchResult(const int64_t nq,
6666 auto ref_x = result_ref[n].top ();
6767 result_ref[n].pop ();
6868 auto index = n * topk + topk - 1 - k;
69- auto id = search_result .get_seg_offsets ()[index];
70- auto distance = search_result .get_distances ()[index];
69+ auto id = result .get_seg_offsets ()[index];
70+ auto distance = result .get_distances ()[index];
7171 ASSERT_EQ (id, ref_x);
7272 ASSERT_EQ (distance, ref_x);
7373 }
@@ -76,19 +76,19 @@ CheckSubSearchResult(const int64_t nq,
7676
7777template <class queue_type >
7878void
79- TestSubSearchResultMerge (const knowhere::MetricType& metric_type) {
80- int64_t num_queries = 16 ;
81- int64_t topk = 10 ;
82- int64_t iteration = 10 ;
83- int64_t round_decimal = 3 ;
79+ TestSubSearchResultMerge (const knowhere::MetricType& metric_type,
80+ const int64_t iteration,
81+ const int64_t nq,
82+ const int64_t topk) {
83+ const int64_t round_decimal = 3 ;
8484
85- std::vector<queue_type> result_ref (num_queries );
85+ std::vector<queue_type> result_ref (nq );
8686
87- SubSearchResult final_result (num_queries , topk, metric_type, round_decimal);
87+ SubSearchResult final_result (nq , topk, metric_type, round_decimal);
8888 for (int i = 0 ; i < iteration; ++i) {
89- SubSearchResultUniq sub_result = GenSubSearchResult (num_queries , topk, metric_type, round_decimal);
89+ SubSearchResultUniq sub_result = GenSubSearchResult (nq , topk, metric_type, round_decimal);
9090 auto ids = sub_result->get_ids ();
91- for (int n = 0 ; n < num_queries ; ++n) {
91+ for (int n = 0 ; n < nq ; ++n) {
9292 for (int k = 0 ; k < topk; ++k) {
9393 int64_t x = ids[n * topk + k];
9494 result_ref[n].push (x);
@@ -99,12 +99,28 @@ TestSubSearchResultMerge(const knowhere::MetricType& metric_type) {
9999 }
100100 final_result.merge (*sub_result);
101101 }
102- CheckSubSearchResult<queue_type>(num_queries , topk, final_result, result_ref);
102+ CheckSubSearchResult<queue_type>(nq , topk, final_result, result_ref);
103103}
104104
105105TEST (Reduce, SubSearchResult) {
106106 using queue_type_l2 = std::priority_queue<int64_t , std::vector<int64_t >, std::less<int64_t >>;
107107 using queue_type_ip = std::priority_queue<int64_t , std::vector<int64_t >, std::greater<int64_t >>;
108- TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 );
109- TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP );
108+
109+ TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 , 1 , 1 , 1 );
110+ TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 , 1 , 1 , 10 );
111+ TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 , 1 , 16 , 1 );
112+ TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 , 1 , 16 , 10 );
113+ TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 , 4 , 1 , 1 );
114+ TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 , 4 , 1 , 10 );
115+ TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 , 4 , 16 , 1 );
116+ TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2 , 4 , 16 , 10 );
117+
118+ TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP , 1 , 1 , 1 );
119+ TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP , 1 , 1 , 10 );
120+ TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP , 1 , 16 , 1 );
121+ TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP , 1 , 16 , 10 );
122+ TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP , 4 , 1 , 1 );
123+ TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP , 4 , 1 , 10 );
124+ TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP , 4 , 16 , 1 );
125+ TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP , 4 , 16 , 10 );
110126}
0 commit comments