Skip to content

Commit e28444b

Browse files
committed
Add cohere_ef
1 parent fb5d793 commit e28444b

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

rabbithole/wikipedia.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,28 @@
11
"""rabbithole.wikipedia module"""
22

3+
import os
4+
35
from chromadb.api import Collection
46
from chromadb.errors import ChromaError
7+
from chromadb.utils import embedding_functions
58
from datasets import load_dataset
69
from tqdm import tqdm
710

811
from rabbithole.vecstore import client
912

13+
cohere_ef = embedding_functions.CohereEmbeddingFunction(
14+
api_key=os.getenv("COHERE_API_KEY"),
15+
model_name="multilingual-22-12",
16+
)
17+
1018

1119
def get_wikipedia_collection() -> Collection:
1220
"""
1321
Get the wikipedia collection
1422
:return: The wikipedia collection
1523
"""
1624
try:
17-
return client.get_collection("wikipedia")
25+
return client.get_collection("wikipedia", embedding_function=cohere_ef)
1826
except (ValueError, ChromaError):
1927
return prepare_wikipedia_collection()
2028

@@ -35,7 +43,7 @@ def prepare_wikipedia_collection(batch_size: int = 10000) -> Collection:
3543
print("Wikipedia collection already exists. Deleting and recreating...")
3644
client.delete_collection("wikipedia")
3745
break
38-
collection = client.create_collection("wikipedia")
46+
collection = client.create_collection("wikipedia", embedding_function=cohere_ef)
3947

4048
total_rows = len(wikipedia_dataset)
4149
with tqdm(total=total_rows, desc='Processing batches', unit='vectors') as pbar:

0 commit comments

Comments
 (0)