Skip to content

Commit 17b88b8

Browse files
committed
Add basic st.session_state
1 parent 20ef688 commit 17b88b8

1 file changed

Lines changed: 25 additions & 13 deletions

File tree

app.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
from rabbithole.mp3 import SUPPORTED_AV_FILE_TYPES
1111
from rabbithole.planner import generate_plan
1212

13-
# Global variables
14-
global_documents = {}
15-
global_embeddings = {}
16-
global_keywords = {}
17-
global_summaries = {}
13+
# Session variables
14+
for state_var in ["uploaded_files", "documents", "embeddings", "keywords", "summaries"]:
15+
if state_var not in st.session_state:
16+
st.session_state[state_var] = {}
1817

1918

2019
def load_files_with_spinner(files: list) -> dict[str, list[Document]]:
@@ -80,7 +79,7 @@ def generate_summary_with_spinner(documents: dict[str, list[Document]]) -> dict[
8079
def generate_plan_with_spinner() -> dict:
8180
"""Generate a logical plan to study the uploaded documents."""
8281
with st.spinner("Generating plan..."):
83-
plan = generate_plan(global_summaries, global_keywords)
82+
plan = generate_plan(st.session_state.summaries, st.session_state.keywords)
8483
return plan
8584

8685

@@ -97,17 +96,30 @@ def generate_plan_with_spinner() -> dict:
9796
st.warning("Please upload a file first.")
9897
st.stop()
9998

100-
# Load the text from the uploaded PDF files
101-
global_documents = load_files_with_spinner(uploaded_files)
102-
global_embeddings = embed_documents_with_spinner(global_documents)
103-
global_keywords = extract_keywords_with_spinner(global_embeddings)
104-
global_summaries = generate_summary_with_spinner(global_documents)
99+
# Check if uploaded files have changed
100+
uploaded_files_changed = False
101+
if len(uploaded_files) != len(st.session_state.uploaded_files):
102+
uploaded_files_changed = True
103+
else:
104+
for new_file, old_file in zip(uploaded_files, st.session_state.uploaded_files):
105+
if new_file != old_file:
106+
uploaded_files_changed = True
107+
break
108+
109+
if uploaded_files_changed:
110+
st.session_state.uploaded_files = uploaded_files
111+
112+
# Load the text from the uploaded PDF files
113+
st.session_state.documents = load_files_with_spinner(st.session_state.uploaded_files)
114+
st.session_state.embeddings = embed_documents_with_spinner(st.session_state.documents)
115+
st.session_state.keywords = extract_keywords_with_spinner(st.session_state.embeddings)
116+
st.session_state.summaries = generate_summary_with_spinner(st.session_state.documents)
105117

106118
# Display the keywords and summaries
107-
for doc_name, doc_keywords in global_keywords.items():
119+
for doc_name, doc_keywords in st.session_state.keywords.items():
108120
st.header(doc_name)
109121
st.caption("Keywords: " + ", ".join(doc_keywords))
110-
st.write(global_summaries[doc_name])
122+
st.write(st.session_state.summaries[doc_name])
111123
st.divider()
112124

113125
# Display the plan

0 commit comments

Comments
 (0)