@@ -10,15 +10,15 @@ use crate::util::{HeapPointer, IndexPointer, ItemPointer};
1010
1111use super :: graph_neighbor_store:: GraphNeighborStore ;
1212
13+ use super :: neighbor_with_distance:: { Distance , DistanceWithTieBreak } ;
1314use super :: pg_vector:: PgVector ;
1415use super :: stats:: { GreedySearchStats , InsertStats , PruneNeighborStats , StatsNodeVisit } ;
1516use super :: storage:: Storage ;
1617use super :: { meta_page:: MetaPage , neighbor_with_distance:: NeighborWithDistance } ;
1718
1819pub struct ListSearchNeighbor < PD > {
1920 pub index_pointer : IndexPointer ,
20- distance : f32 ,
21- distance_tie_break : usize , /* only used if distance = 0. This ensures a consistent order of results when distance = 0 */
21+ distance_with_tie_break : DistanceWithTieBreak ,
2222 private_data : PD ,
2323}
2424
@@ -38,23 +38,21 @@ impl<PD> Eq for ListSearchNeighbor<PD> {}
3838
3939impl < PD > Ord for ListSearchNeighbor < PD > {
4040 fn cmp ( & self , other : & Self ) -> Ordering {
41- if self . distance == 0.0 && other. distance == 0.0 {
42- /* this logic should be consistent with what's used during pruning */
43- return self . distance_tie_break . cmp ( & other. distance_tie_break ) ;
44- }
45- self . distance . total_cmp ( & other. distance )
41+ self . distance_with_tie_break
42+ . cmp ( & other. distance_with_tie_break )
4643 }
4744}
4845
4946impl < PD > ListSearchNeighbor < PD > {
50- pub fn new ( index_pointer : IndexPointer , distance : f32 , private_data : PD ) -> Self {
51- assert ! ( !distance. is_nan( ) ) ;
52- debug_assert ! ( distance >= 0.0 ) ;
47+ pub fn new (
48+ index_pointer : IndexPointer ,
49+ distance_with_tie_break : DistanceWithTieBreak ,
50+ private_data : PD ,
51+ ) -> Self {
5352 Self {
5453 index_pointer,
5554 private_data,
56- distance,
57- distance_tie_break : 0 ,
55+ distance_with_tie_break,
5856 }
5957 }
6058
@@ -116,16 +114,22 @@ impl<QDM, PD> ListSearchResult<QDM, PD> {
116114 self . inserted . insert ( ip)
117115 }
118116
119- /// Internal function
120- pub fn insert_neighbor ( & mut self , mut n : ListSearchNeighbor < PD > ) {
121- self . stats . record_candidate ( ) ;
122- if n. distance == 0.0 {
123- /* record the tie break if distance is 0 */
124- if let Some ( tie_break_item_pointer) = self . tie_break_item_pointer {
125- let d = tie_break_item_pointer. ip_distance ( n. index_pointer ) ;
126- n. distance_tie_break = d;
117+ pub fn create_distance_with_tie_break (
118+ & self ,
119+ d : Distance ,
120+ ip : ItemPointer ,
121+ ) -> DistanceWithTieBreak {
122+ match self . tie_break_item_pointer {
123+ None => DistanceWithTieBreak :: with_query ( d, ip) ,
124+ Some ( tie_break_item_pointer) => {
125+ DistanceWithTieBreak :: new ( d, ip, tie_break_item_pointer)
127126 }
128127 }
128+ }
129+
130+ /// Internal function
131+ pub fn insert_neighbor ( & mut self , n : ListSearchNeighbor < PD > ) {
132+ self . stats . record_candidate ( ) ;
129133 self . candidates . push ( Reverse ( n) ) ;
130134 }
131135
@@ -332,8 +336,7 @@ impl<'a> Graph<'a> {
332336 let list_search_entry = & lsr. visited [ list_search_entry_idx] ;
333337 visited_nodes. insert ( NeighborWithDistance :: new (
334338 list_search_entry. index_pointer ,
335- list_search_entry. distance ,
336- list_search_entry. distance_tie_break ,
339+ list_search_entry. distance_with_tie_break . clone ( ) ,
337340 ) ) ;
338341 }
339342 }
@@ -349,7 +352,7 @@ impl<'a> Graph<'a> {
349352 /// if we save the factors or the distances and add incrementally. Not sure.
350353 pub fn prune_neighbors < S : Storage > (
351354 & self ,
352- neighbors_of : ItemPointer ,
355+ _neighbors_of : ItemPointer ,
353356 mut candidates : Vec < NeighborWithDistance > ,
354357 storage : & S ,
355358 stats : & mut PruneNeighborStats ,
@@ -410,8 +413,9 @@ impl<'a> Graph<'a> {
410413 dist_state
411414 . get_distance ( candidate_neighbor. get_index_pointer_to_neighbor ( ) , stats)
412415 } ;
413- let mut distance_between_candidate_and_point =
414- candidate_neighbor. get_distance ( ) ;
416+ let mut distance_between_candidate_and_point = candidate_neighbor
417+ . get_distance_with_tie_break ( )
418+ . get_distance ( ) ;
415419
416420 //We need both values to be positive.
417421 //Otherwise, the case where distance_between_candidate_and_point > 0 and distance_between_candidate_and_existing_neighbor < 0 is totally wrong.
@@ -451,8 +455,8 @@ impl<'a> Graph<'a> {
451455
452456 Note: with sbq these equivalence relations are actually not uncommon */
453457 let ip_distance_between_candidate_and_point = candidate_neighbor
454- . get_index_pointer_to_neighbor ( )
455- . ip_distance ( neighbors_of ) ;
458+ . get_distance_with_tie_break ( )
459+ . get_distance_tie_break ( ) ;
456460
457461 let ip_distance_between_candidate_and_existing_neighbor =
458462 candidate_neighbor
@@ -527,8 +531,7 @@ impl<'a> Graph<'a> {
527531 let ( _needed_prune, contains) = self . update_back_pointer (
528532 neighbor. get_index_pointer_to_neighbor ( ) ,
529533 index_pointer,
530- neighbor. get_distance ( ) ,
531- neighbor. get_distance_tie_break ( ) ,
534+ neighbor. get_distance_with_tie_break ( ) ,
532535 storage,
533536 & mut stats. prune_neighbor_stats ,
534537 ) ;
@@ -545,12 +548,14 @@ impl<'a> Graph<'a> {
545548 & mut self ,
546549 from : IndexPointer ,
547550 to : IndexPointer ,
548- distance : f32 ,
549- distance_tie_break : usize ,
551+ distance_with_tie_break : & DistanceWithTieBreak ,
550552 storage : & S ,
551553 prune_stats : & mut PruneNeighborStats ,
552554 ) -> ( bool , bool ) {
553- let new = vec ! [ NeighborWithDistance :: new( to, distance, distance_tie_break) ] ;
555+ let new = vec ! [ NeighborWithDistance :: new(
556+ to,
557+ distance_with_tie_break. clone( ) ,
558+ ) ] ;
554559 let ( pruned, n) = self . add_neighbors ( storage, from, new. clone ( ) , prune_stats) ;
555560 ( pruned, n. contains ( & new[ 0 ] ) )
556561 }
0 commit comments