1010from rabbithole .mp3 import SUPPORTED_AV_FILE_TYPES
1111from rabbithole .planner import generate_plan
1212
13- # Global variables
14- global_documents = {}
15- global_embeddings = {}
16- global_keywords = {}
17- global_summaries = {}
13+ # Session variables
14+ for state_var in ["uploaded_files" , "documents" , "embeddings" , "keywords" , "summaries" ]:
15+ if state_var not in st .session_state :
16+ st .session_state [state_var ] = {}
1817
1918
2019def load_files_with_spinner (files : list ) -> dict [str , list [Document ]]:
@@ -80,7 +79,7 @@ def generate_summary_with_spinner(documents: dict[str, list[Document]]) -> dict[
8079def generate_plan_with_spinner () -> dict :
8180 """Generate a logical plan to study the uploaded documents."""
8281 with st .spinner ("Generating plan..." ):
83- plan = generate_plan (global_summaries , global_keywords )
82+ plan = generate_plan (st . session_state . summaries , st . session_state . keywords )
8483 return plan
8584
8685
@@ -97,17 +96,30 @@ def generate_plan_with_spinner() -> dict:
9796 st .warning ("Please upload a file first." )
9897 st .stop ()
9998
100- # Load the text from the uploaded PDF files
101- global_documents = load_files_with_spinner (uploaded_files )
102- global_embeddings = embed_documents_with_spinner (global_documents )
103- global_keywords = extract_keywords_with_spinner (global_embeddings )
104- global_summaries = generate_summary_with_spinner (global_documents )
99+ # Check if uploaded files have changed
100+ uploaded_files_changed = False
101+ if len (uploaded_files ) != len (st .session_state .uploaded_files ):
102+ uploaded_files_changed = True
103+ else :
104+ for new_file , old_file in zip (uploaded_files , st .session_state .uploaded_files ):
105+ if new_file != old_file :
106+ uploaded_files_changed = True
107+ break
108+
109+ if uploaded_files_changed :
110+ st .session_state .uploaded_files = uploaded_files
111+
112+ # Load the text from the uploaded PDF files
113+ st .session_state .documents = load_files_with_spinner (st .session_state .uploaded_files )
114+ st .session_state .embeddings = embed_documents_with_spinner (st .session_state .documents )
115+ st .session_state .keywords = extract_keywords_with_spinner (st .session_state .embeddings )
116+ st .session_state .summaries = generate_summary_with_spinner (st .session_state .documents )
105117
106118 # Display the keywords and summaries
107- for doc_name , doc_keywords in global_keywords .items ():
119+ for doc_name , doc_keywords in st . session_state . keywords .items ():
108120 st .header (doc_name )
109121 st .caption ("Keywords: " + ", " .join (doc_keywords ))
110- st .write (global_summaries [doc_name ])
122+ st .write (st . session_state . summaries [doc_name ])
111123 st .divider ()
112124
113125 # Display the plan
0 commit comments