Skip to content

Commit 460de91

Browse files
committed
chore: refactor to add proper DistanceWithTieBreak struct.
Previously this logic was all over the place. Now consolidated.
1 parent 9046adb commit 460de91

5 files changed

Lines changed: 136 additions & 69 deletions

File tree

pgvectorscale/src/access_method/build.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -864,11 +864,7 @@ pub mod tests {
864864
Ok(())
865865
}
866866

867-
pub fn verify_index_accuracy(
868-
index_options: &str,
869-
expected_cnt: i64,
870-
dimensions: usize,
871-
) -> spi::Result<()> {
867+
pub fn verify_index_accuracy(expected_cnt: i64, dimensions: usize) -> spi::Result<()> {
872868
let test_vec: Option<Vec<f32>> = Spi::get_one(&format!(
873869
"SELECT('{{' || array_to_string(array_agg(1.0), ',', '0') || '}}')::real[] AS embedding
874870
FROM generate_series(1, {dimensions})"
@@ -966,7 +962,8 @@ pub mod tests {
966962
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
967963
FROM generate_series(1, {dimensions}));"))?;
968964

969-
verify_index_accuracy(index_options, expected_cnt, dimensions)
965+
verify_index_accuracy(expected_cnt, dimensions)?;
966+
Ok(())
970967
}
971968

972969
#[pg_test]
@@ -1012,7 +1009,7 @@ pub mod tests {
10121009
('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding
10131010
FROM generate_series(1, {dimensions}));"))?;
10141011

1015-
verify_index_accuracy(index_options, expected_cnt, dimensions)?;
1012+
verify_index_accuracy(expected_cnt, dimensions)?;
10161013
Ok(())
10171014
}
10181015
}

pgvectorscale/src/access_method/graph.rs

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ use crate::util::{HeapPointer, IndexPointer, ItemPointer};
1010

1111
use super::graph_neighbor_store::GraphNeighborStore;
1212

13+
use super::neighbor_with_distance::{Distance, DistanceWithTieBreak};
1314
use super::pg_vector::PgVector;
1415
use super::stats::{GreedySearchStats, InsertStats, PruneNeighborStats, StatsNodeVisit};
1516
use super::storage::Storage;
1617
use super::{meta_page::MetaPage, neighbor_with_distance::NeighborWithDistance};
1718

1819
pub struct ListSearchNeighbor<PD> {
1920
pub index_pointer: IndexPointer,
20-
distance: f32,
21-
distance_tie_break: usize, /* only used if distance = 0. This ensures a consistent order of results when distance = 0 */
21+
distance_with_tie_break: DistanceWithTieBreak,
2222
private_data: PD,
2323
}
2424

@@ -38,23 +38,21 @@ impl<PD> Eq for ListSearchNeighbor<PD> {}
3838

3939
impl<PD> Ord for ListSearchNeighbor<PD> {
4040
fn cmp(&self, other: &Self) -> Ordering {
41-
if self.distance == 0.0 && other.distance == 0.0 {
42-
/* this logic should be consistent with what's used during pruning */
43-
return self.distance_tie_break.cmp(&other.distance_tie_break);
44-
}
45-
self.distance.total_cmp(&other.distance)
41+
self.distance_with_tie_break
42+
.cmp(&other.distance_with_tie_break)
4643
}
4744
}
4845

4946
impl<PD> ListSearchNeighbor<PD> {
50-
pub fn new(index_pointer: IndexPointer, distance: f32, private_data: PD) -> Self {
51-
assert!(!distance.is_nan());
52-
debug_assert!(distance >= 0.0);
47+
pub fn new(
48+
index_pointer: IndexPointer,
49+
distance_with_tie_break: DistanceWithTieBreak,
50+
private_data: PD,
51+
) -> Self {
5352
Self {
5453
index_pointer,
5554
private_data,
56-
distance,
57-
distance_tie_break: 0,
55+
distance_with_tie_break,
5856
}
5957
}
6058

@@ -116,16 +114,22 @@ impl<QDM, PD> ListSearchResult<QDM, PD> {
116114
self.inserted.insert(ip)
117115
}
118116

119-
/// Internal function
120-
pub fn insert_neighbor(&mut self, mut n: ListSearchNeighbor<PD>) {
121-
self.stats.record_candidate();
122-
if n.distance == 0.0 {
123-
/* record the tie break if distance is 0 */
124-
if let Some(tie_break_item_pointer) = self.tie_break_item_pointer {
125-
let d = tie_break_item_pointer.ip_distance(n.index_pointer);
126-
n.distance_tie_break = d;
117+
pub fn create_distance_with_tie_break(
118+
&self,
119+
d: Distance,
120+
ip: ItemPointer,
121+
) -> DistanceWithTieBreak {
122+
match self.tie_break_item_pointer {
123+
None => DistanceWithTieBreak::with_query(d, ip),
124+
Some(tie_break_item_pointer) => {
125+
DistanceWithTieBreak::new(d, ip, tie_break_item_pointer)
127126
}
128127
}
128+
}
129+
130+
/// Internal function
131+
pub fn insert_neighbor(&mut self, n: ListSearchNeighbor<PD>) {
132+
self.stats.record_candidate();
129133
self.candidates.push(Reverse(n));
130134
}
131135

@@ -332,8 +336,7 @@ impl<'a> Graph<'a> {
332336
let list_search_entry = &lsr.visited[list_search_entry_idx];
333337
visited_nodes.insert(NeighborWithDistance::new(
334338
list_search_entry.index_pointer,
335-
list_search_entry.distance,
336-
list_search_entry.distance_tie_break,
339+
list_search_entry.distance_with_tie_break.clone(),
337340
));
338341
}
339342
}
@@ -349,7 +352,7 @@ impl<'a> Graph<'a> {
349352
/// if we save the factors or the distances and add incrementally. Not sure.
350353
pub fn prune_neighbors<S: Storage>(
351354
&self,
352-
neighbors_of: ItemPointer,
355+
_neighbors_of: ItemPointer,
353356
mut candidates: Vec<NeighborWithDistance>,
354357
storage: &S,
355358
stats: &mut PruneNeighborStats,
@@ -410,8 +413,9 @@ impl<'a> Graph<'a> {
410413
dist_state
411414
.get_distance(candidate_neighbor.get_index_pointer_to_neighbor(), stats)
412415
};
413-
let mut distance_between_candidate_and_point =
414-
candidate_neighbor.get_distance();
416+
let mut distance_between_candidate_and_point = candidate_neighbor
417+
.get_distance_with_tie_break()
418+
.get_distance();
415419

416420
//We need both values to be positive.
417421
//Otherwise, the case where distance_between_candidate_and_point > 0 and distance_between_candidate_and_existing_neighbor < 0 is totally wrong.
@@ -451,8 +455,8 @@ impl<'a> Graph<'a> {
451455
452456
Note: with sbq these equivalence relations are actually not uncommon */
453457
let ip_distance_between_candidate_and_point = candidate_neighbor
454-
.get_index_pointer_to_neighbor()
455-
.ip_distance(neighbors_of);
458+
.get_distance_with_tie_break()
459+
.get_distance_tie_break();
456460

457461
let ip_distance_between_candidate_and_existing_neighbor =
458462
candidate_neighbor
@@ -527,8 +531,7 @@ impl<'a> Graph<'a> {
527531
let (_needed_prune, contains) = self.update_back_pointer(
528532
neighbor.get_index_pointer_to_neighbor(),
529533
index_pointer,
530-
neighbor.get_distance(),
531-
neighbor.get_distance_tie_break(),
534+
neighbor.get_distance_with_tie_break(),
532535
storage,
533536
&mut stats.prune_neighbor_stats,
534537
);
@@ -545,12 +548,14 @@ impl<'a> Graph<'a> {
545548
&mut self,
546549
from: IndexPointer,
547550
to: IndexPointer,
548-
distance: f32,
549-
distance_tie_break: usize,
551+
distance_with_tie_break: &DistanceWithTieBreak,
550552
storage: &S,
551553
prune_stats: &mut PruneNeighborStats,
552554
) -> (bool, bool) {
553-
let new = vec![NeighborWithDistance::new(to, distance, distance_tie_break)];
555+
let new = vec![NeighborWithDistance::new(
556+
to,
557+
distance_with_tie_break.clone(),
558+
)];
554559
let (pruned, n) = self.add_neighbors(storage, from, new.clone(), prune_stats);
555560
(pruned, n.contains(&new[0]))
556561
}

pgvectorscale/src/access_method/neighbor_with_distance.rs

Lines changed: 82 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,104 @@
1-
use std::cmp::Ordering;
1+
use std::{cell::OnceCell, cmp::Ordering};
22

33
use crate::util::{IndexPointer, ItemPointer};
44

55
//TODO is this right?
66
pub type Distance = f32;
7+
8+
/* implements a distance with a lazy tie break */
79
#[derive(Clone, Debug)]
8-
pub struct NeighborWithDistance {
9-
index_pointer: IndexPointer,
10+
pub struct DistanceWithTieBreak {
1011
distance: Distance,
11-
distance_tie_break: usize,
12+
from: IndexPointer,
13+
to: IndexPointer,
14+
_distance_tie_break: OnceCell<usize>,
1215
}
1316

14-
impl NeighborWithDistance {
15-
pub fn new(
16-
neighbor_index_pointer: ItemPointer,
17-
distance: Distance,
18-
distance_tie_break: usize,
19-
) -> Self {
17+
impl DistanceWithTieBreak {
18+
pub fn new(distance: Distance, from: IndexPointer, to: IndexPointer) -> Self {
2019
assert!(!distance.is_nan());
2120
assert!(distance >= 0.0);
21+
DistanceWithTieBreak {
22+
distance,
23+
from,
24+
to,
25+
_distance_tie_break: OnceCell::new(),
26+
}
27+
}
28+
29+
pub fn with_query(distance: Distance, to: IndexPointer) -> Self {
30+
//this is the distance from the query to a index node.
31+
//make the distance_tie_break = 0
32+
let distance_tie_break = OnceCell::new();
33+
distance_tie_break.set(0).unwrap();
34+
DistanceWithTieBreak {
35+
distance,
36+
from: to,
37+
to,
38+
_distance_tie_break: distance_tie_break,
39+
}
40+
}
41+
42+
pub fn get_distance_tie_break(&self) -> usize {
43+
*self
44+
._distance_tie_break
45+
.get_or_init(|| self.from.ip_distance(self.to))
46+
}
47+
48+
pub fn get_distance(&self) -> Distance {
49+
self.distance
50+
}
51+
}
52+
53+
impl PartialOrd for DistanceWithTieBreak {
54+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
55+
if self.distance == 0.0 && other.distance == 0.0 {
56+
return self
57+
.get_distance_tie_break()
58+
.partial_cmp(&other.get_distance_tie_break());
59+
}
60+
self.distance.partial_cmp(&other.distance)
61+
}
62+
}
63+
64+
impl Ord for DistanceWithTieBreak {
65+
fn cmp(&self, other: &Self) -> Ordering {
66+
self.partial_cmp(other).unwrap()
67+
}
68+
}
69+
70+
impl PartialEq for DistanceWithTieBreak {
71+
fn eq(&self, other: &Self) -> bool {
72+
if self.distance == 0.0 && other.distance == 0.0 {
73+
return self.get_distance_tie_break() == other.get_distance_tie_break();
74+
}
75+
self.distance == other.distance
76+
}
77+
}
78+
79+
//promise that PartialEq is reflexive
80+
impl Eq for DistanceWithTieBreak {}
81+
82+
#[derive(Clone, Debug)]
83+
pub struct NeighborWithDistance {
84+
index_pointer: IndexPointer,
85+
distance: DistanceWithTieBreak,
86+
}
87+
88+
impl NeighborWithDistance {
89+
pub fn new(neighbor_index_pointer: ItemPointer, distance: DistanceWithTieBreak) -> Self {
2290
Self {
2391
index_pointer: neighbor_index_pointer,
2492
distance,
25-
distance_tie_break,
2693
}
2794
}
2895

2996
pub fn get_index_pointer_to_neighbor(&self) -> ItemPointer {
3097
self.index_pointer
3198
}
32-
pub fn get_distance(&self) -> Distance {
33-
self.distance
34-
}
35-
pub fn get_distance_tie_break(&self) -> usize {
36-
return self.distance_tie_break;
99+
100+
pub fn get_distance_with_tie_break(&self) -> &DistanceWithTieBreak {
101+
&self.distance
37102
}
38103
}
39104

@@ -45,10 +110,7 @@ impl PartialOrd for NeighborWithDistance {
45110

46111
impl Ord for NeighborWithDistance {
47112
fn cmp(&self, other: &Self) -> Ordering {
48-
if self.distance == 0.0 && other.distance == 0.0 {
49-
return self.distance_tie_break.cmp(&other.distance_tie_break);
50-
}
51-
self.distance.total_cmp(&other.distance)
113+
self.distance.cmp(&other.distance)
52114
}
53115
}
54116

pgvectorscale/src/access_method/plain_storage.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::{
22
distance::DistanceFn,
33
graph::{ListSearchNeighbor, ListSearchResult},
44
graph_neighbor_store::GraphNeighborStore,
5+
neighbor_with_distance::DistanceWithTieBreak,
56
pg_vector::PgVector,
67
plain_node::{ArchivedNode, Node, ReadableNode},
78
stats::{
@@ -256,8 +257,7 @@ impl<'a> Storage for PlainStorage<'a> {
256257
let dist = unsafe { dist_state.get_distance(n, stats) };
257258
result.push(NeighborWithDistance::new(
258259
n,
259-
dist,
260-
n.ip_distance(neighbors_of),
260+
DistanceWithTieBreak::new(dist, neighbors_of, n),
261261
))
262262
}
263263
}
@@ -288,7 +288,7 @@ impl<'a> Storage for PlainStorage<'a> {
288288

289289
ListSearchNeighbor::new(
290290
index_pointer,
291-
distance,
291+
lsr.create_distance_with_tie_break(distance, index_pointer),
292292
PlainStorageLsnPrivateData::new(index_pointer, node, gns),
293293
)
294294
}
@@ -322,7 +322,7 @@ impl<'a> Storage for PlainStorage<'a> {
322322
};
323323
let lsn = ListSearchNeighbor::new(
324324
neighbor_index_pointer,
325-
distance,
325+
lsr.create_distance_with_tie_break(distance, neighbor_index_pointer),
326326
PlainStorageLsnPrivateData::new(neighbor_index_pointer, node_neighbor, gns),
327327
);
328328

0 commit comments

Comments
 (0)