Skip to content

Commit 163a734

Browse files
committed
Update prepare_wikipedia_collection() to prepare in batches
1 parent 4bdf41e commit 163a734

1 file changed

Lines changed: 13 additions & 10 deletions

File tree

rabbithole/wikipedia.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,29 @@
88
# Create the "wikipedia" collection on the client; the preparation function
# below adds the dataset rows (ids, embeddings, documents, metadatas) to it.
# NOTE(review): this runs at import time and presumably fails if a collection
# named "wikipedia" already exists — confirm against the client library.
wikipedia_collection = client.create_collection("wikipedia")
99

1010

def prepare_wikipedia_collection(batch_size: int = 10000) -> None:
    """
    Prepare the wikipedia collection

    Adds every row of ``wikipedia_dataset`` to ``wikipedia_collection`` in
    batches of at most ``batch_size`` rows, so the whole dataset never has to
    be materialised for a single ``add`` call. Each row's id is its global
    position in the dataset, which keeps ids unique across batches.

    NOTE: Only needs to be run once to prepare the collection for the first time

    Args:
        batch_size: Maximum number of rows sent per ``add`` call. Must be
            a positive integer.

    Raises:
        ValueError: If ``batch_size`` is zero or negative.
    """
    # Validate first: range() rejects a zero step with an unrelated message,
    # and a negative step would silently add nothing at all.
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    # Local import keeps this block self-contained (stdlib only).
    from collections.abc import Mapping

    total_rows = len(wikipedia_dataset)
    for start in range(0, total_rows, batch_size):
        batch = wikipedia_dataset[start:start + batch_size]

        if isinstance(batch, Mapping):
            # Columnar slice: a mapping of column name -> list of values
            # (presumably what a Hugging Face Dataset returns for a slice —
            # TODO confirm against the dataset type used here).
            embeddings = list(batch['emb'])
            documents = list(batch['text'])
            metadatas = list(batch['title'])
        else:
            # Row-wise slice: a sequence of per-row mappings.
            rows = list(batch)
            embeddings = [row['emb'] for row in rows]
            documents = [row['text'] for row in rows]
            metadatas = [row['title'] for row in rows]

        # Ids are offset by the batch start so they are globally unique;
        # numbering each batch from 0 would make later batches collide with
        # (or overwrite) earlier ones.
        # NOTE(review): ids stay integers and metadatas stay bare titles, as
        # in the original code — confirm the collection's add() accepts these
        # (some vector stores require string ids and dict metadatas).
        ids = list(range(start, start + len(documents)))

        wikipedia_collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )
2629

2730

2831
if __name__ == "__main__":
    # Show the dataset's info before doing any work.
    print(wikipedia_dataset.info)

    # One-off population of the collection, using the function's default
    # batch size.
    print("Preparing wikipedia collection...")
    prepare_wikipedia_collection()
    print("Done!")

0 commit comments

Comments
 (0)