11"""rabbithole.wikipedia module"""
22
3+ import os
4+
35from chromadb .api import Collection
46from chromadb .errors import ChromaError
7+ from chromadb .utils import embedding_functions
58from datasets import load_dataset
69from tqdm import tqdm
710
811from rabbithole .vecstore import client
912
13+ cohere_ef = embedding_functions .CohereEmbeddingFunction (
14+ api_key = os .getenv ("COHERE_API_KEY" ),
15+ model_name = "multilingual-22-12" ,
16+ )
17+
1018
1119def get_wikipedia_collection () -> Collection :
1220 """
1321 Get the wikipedia collection
1422 :return: The wikipedia collection
1523 """
1624 try :
17- return client .get_collection ("wikipedia" )
25+ return client .get_collection ("wikipedia" , embedding_function = cohere_ef )
1826 except (ValueError , ChromaError ):
1927 return prepare_wikipedia_collection ()
2028
@@ -35,7 +43,7 @@ def prepare_wikipedia_collection(batch_size: int = 10000) -> Collection:
3543 print ("Wikipedia collection already exists. Deleting and recreating..." )
3644 client .delete_collection ("wikipedia" )
3745 break
38- collection = client .create_collection ("wikipedia" )
46+ collection = client .create_collection ("wikipedia" , embedding_function = cohere_ef )
3947
4048 total_rows = len (wikipedia_dataset )
4149 with tqdm (total = total_rows , desc = 'Processing batches' , unit = 'vectors' ) as pbar :
0 commit comments