11"""Streamlit App"""
22
3+ import openai
34import streamlit as st
45from langchain .schema import Document
6+ from streamlit_chat import message
57
68from rabbithole import summarize_document
79from rabbithole .embedding import embed_document
1012from rabbithole .mp3 import SUPPORTED_AV_FILE_TYPES
1113from rabbithole .planner import generate_plan
1214
# Session variables
# Each entry is created only when it is not already present in the session
# state, so values stored earlier in the session are left untouched.
for state_var in ("uploaded_files", "documents", "embeddings", "keywords", "summaries"):
    if state_var not in st.session_state:
        st.session_state[state_var] = {}

# Generated study plan; None until one has been produced.
if "plan" not in st.session_state:
    st.session_state["plan"] = None

# Set to True once the uploaded files have been processed.
if "processed" not in st.session_state:
    st.session_state["processed"] = False

# Chat history: the bot opens the conversation, the user list starts empty.
if "bot_messages" not in st.session_state:
    st.session_state["bot_messages"] = ["Hello, I am here to help you learn more efficiently"]
if "user_messages" not in st.session_state:
    st.session_state["user_messages"] = []
29+
30+
def generate_response(prompt):
    """Generate a reply to ``prompt`` with the GPT-4 chat model and store it.

    The prompt is appended to ``st.session_state['user_messages']``, the whole
    conversation is sent to ``openai.ChatCompletion.create``, and the reply is
    appended to ``st.session_state['bot_messages']``.

    Args:
        prompt: The user's latest message.

    Returns:
        None. The reply is recorded in the session state.
    """
    st.session_state['user_messages'].append(prompt)

    bot_messages = st.session_state['bot_messages']
    # Read the history list 'user_messages'. The similarly named key
    # 'user_message' belongs to the text-input widget and holds a string.
    user_messages = st.session_state['user_messages']

    # Interleave the history as assistant, user, assistant, user, ... and stop
    # as soon as either side has no message at the current position.
    messages = []
    for msg_idx, bot_message in enumerate(bot_messages):
        messages.append({"role": "assistant", "content": bot_message})
        if msg_idx >= len(user_messages):
            break
        messages.append({"role": "user", "content": user_messages[msg_idx]})
    print(messages)

    completion = openai.ChatCompletion.create(
        model="gpt-4",
        messages=messages,
    )
    response = completion.choices[0].message.content
    st.session_state['bot_messages'].append(response)
    print(response)
1856
1957
2058def load_files_with_spinner (files : list ) -> dict [str , list [Document ]]:
@@ -80,39 +118,65 @@ def generate_summary_with_spinner(documents: dict[str, list[Document]]) -> dict[
def generate_plan_with_spinner() -> dict:
    """Build a study plan for the uploaded documents while showing a spinner.

    Passes the summaries and keywords held in the session state to
    ``generate_plan`` and returns its result.
    """
    summaries = st.session_state.summaries
    keywords = st.session_state.keywords
    with st.spinner("Generating plan..."):
        return generate_plan(summaries, keywords)
85123
86124
st.set_page_config(page_title="RabbitHole", page_icon="🐇", layout="wide")

st.title("RabbitHole")

# The upload form is shown only until the files have been processed.
if not st.session_state.processed:
    uploaded_files = st.file_uploader(
        "Upload content",
        type=["docx", "pdf", "txt", *SUPPORTED_IMG_FILE_TYPES, *SUPPORTED_AV_FILE_TYPES],
        accept_multiple_files=True,
    )

    if st.button("Dive in"):
        if not uploaded_files:
            st.warning("Please upload a file first.")
            st.stop()

        # The selection counts as changed when the number of files differs
        # or any file differs from the one stored at the same position.
        uploaded_files_changed = (
            len(uploaded_files) != len(st.session_state.uploaded_files)
            or any(
                new_file != old_file
                for new_file, old_file in zip(uploaded_files, st.session_state.uploaded_files)
            )
        )

        if uploaded_files_changed:
            st.session_state.uploaded_files = uploaded_files

            # Load, embed, extract keywords from and summarise the uploaded files
            st.session_state.documents = load_files_with_spinner(st.session_state.uploaded_files)
            st.session_state.embeddings = embed_documents_with_spinner(st.session_state.documents)
            st.session_state.keywords = extract_keywords_with_spinner(st.session_state.embeddings)
            st.session_state.summaries = generate_summary_with_spinner(st.session_state.documents)

        # Display the keywords and summaries
        for doc_name, doc_keywords in st.session_state.keywords.items():
            st.header(doc_name)
            st.caption(f"Keywords: {', '.join(doc_keywords)}")
            st.write(st.session_state.summaries[doc_name])
            st.divider()

        st.session_state.processed = True
        st.success('Summarization completed.')
167+
168+ if st .session_state .processed :
169+ st .header ("Loaded Files" )
170+ for file in st .session_state .uploaded_files :
171+ st .write (file .name )
112172
113173 # Display the plan
114174 st .header ("Study Plan" )
115- plan = generate_plan_with_spinner ()
175+ if st .session_state .plan is None :
176+ plan = generate_plan_with_spinner ()
177+ st .session_state .plan = plan
178+ else :
179+ plan = st .session_state .plan
116180 for data in plan .get ("plan" , []):
117181 for doc_name , doc_data in data .items ():
118182 st .subheader (doc_name )
@@ -127,4 +191,22 @@ def generate_plan_with_spinner() -> dict:
127191 st .write (f"- { concept } " )
128192 st .write ("" )
129193
130- st .success ('Summarization completed.' )
194+ st .header ("Chat" )
195+ # Iterate through the bot and user message and print them alternatively
196+ message_i = 0
197+ while True :
198+ if len (st .session_state .bot_messages ) > message_i :
199+ message (st .session_state .bot_messages [message_i ])
200+ else :
201+ break
202+ if len (st .session_state .user_messages ) > message_i :
203+ message (st .session_state .user_messages [message_i ], is_user = True )
204+ else :
205+ break
206+ message_i += 1
207+
208+ user_input = st .text_input ("What do you want to learn more about?" , key = "user_message" )
209+ if st .button ("Send" ):
210+ with st .spinner ("Generating response..." ):
211+ generate_response (user_input )
212+ st .experimental_rerun ()
0 commit comments