Skip to content

Commit f5b32d7

Browse files
committed
Add support for chat
1 parent 17b88b8 commit f5b32d7

1 file changed

Lines changed: 116 additions & 53 deletions

File tree

app.py

Lines changed: 116 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Streamlit App"""
22

3+
import openai
34
import streamlit as st
45
from langchain.schema import Document
6+
from streamlit_chat import message
57

68
from rabbithole import summarize_document
79
from rabbithole.embedding import embed_document
@@ -14,6 +16,41 @@
1416
for state_var in ["uploaded_files", "documents", "embeddings", "keywords", "summaries"]:
1517
if state_var not in st.session_state:
1618
st.session_state[state_var] = {}
19+
if "processed" not in st.session_state:
20+
st.session_state.processed = False
21+
if "bot_messages" not in st.session_state:
22+
st.session_state.bot_messages = [
23+
"Hello, I am here to help you learn more efficiently"
24+
]
25+
if "user_messages" not in st.session_state:
26+
st.session_state.user_messages = []
27+
28+
29+
def generate_response(prompt):
30+
"""Generate a response to a prompt using the GPT-3.5 Turbo model"""
31+
st.session_state['user_messages'].append(prompt)
32+
33+
messages = []
34+
# Alternate between bot and the assistant until the conversation is over
35+
msg_idx = 0
36+
while True:
37+
if len(st.session_state['bot_messages']) > msg_idx:
38+
messages.append({"role": "assistant", "content": st.session_state['bot_messages'][msg_idx]})
39+
else:
40+
break
41+
if len(st.session_state['user_message']) > msg_idx:
42+
messages.append({"role": "user", "content": st.session_state['user_messages'][msg_idx]})
43+
else:
44+
break
45+
msg_idx += 1
46+
print(messages)
47+
completion = openai.ChatCompletion.create(
48+
model="gpt-4",
49+
messages=messages
50+
)
51+
response = completion.choices[0].message.content
52+
st.session_state['bot_messages'].append(response)
53+
print(response)
1754

1855

1956
def load_files_with_spinner(files: list) -> dict[str, list[Document]]:
@@ -87,56 +124,82 @@ def generate_plan_with_spinner() -> dict:
87124

88125
st.title("RabbitHole")
89126

90-
uploaded_files = st.file_uploader("Upload content",
91-
type=["docx", "pdf", "txt", *SUPPORTED_IMG_FILE_TYPES, *SUPPORTED_AV_FILE_TYPES],
92-
accept_multiple_files=True)
93-
94-
if st.button("Dive in"):
95-
if not uploaded_files:
96-
st.warning("Please upload a file first.")
97-
st.stop()
98-
99-
# Check if uploaded files have changed
100-
uploaded_files_changed = False
101-
if len(uploaded_files) != len(st.session_state.uploaded_files):
102-
uploaded_files_changed = True
103-
else:
104-
for new_file, old_file in zip(uploaded_files, st.session_state.uploaded_files):
105-
if new_file != old_file:
106-
uploaded_files_changed = True
107-
break
108-
109-
if uploaded_files_changed:
110-
st.session_state.uploaded_files = uploaded_files
111-
112-
# Load the text from the uploaded PDF files
113-
st.session_state.documents = load_files_with_spinner(st.session_state.uploaded_files)
114-
st.session_state.embeddings = embed_documents_with_spinner(st.session_state.documents)
115-
st.session_state.keywords = extract_keywords_with_spinner(st.session_state.embeddings)
116-
st.session_state.summaries = generate_summary_with_spinner(st.session_state.documents)
117-
118-
# Display the keywords and summaries
119-
for doc_name, doc_keywords in st.session_state.keywords.items():
120-
st.header(doc_name)
121-
st.caption("Keywords: " + ", ".join(doc_keywords))
122-
st.write(st.session_state.summaries[doc_name])
123-
st.divider()
124-
125-
# Display the plan
126-
st.header("Study Plan")
127-
plan = generate_plan_with_spinner()
128-
for data in plan.get("plan", []):
129-
for doc_name, doc_data in data.items():
130-
st.subheader(doc_name)
131-
st.write(f"**Background Concepts**")
132-
for concept in doc_data.get("Background Concepts", []):
133-
st.write(f"- {concept}")
134-
st.write(f"**Key Concepts**")
135-
for concept in doc_data.get("Key Concepts", []):
136-
st.write(f"- {concept}")
137-
st.write(f"**Further Reading**")
138-
for concept in doc_data.get("Further Reading", []):
139-
st.write(f"- {concept}")
140-
st.write("")
141-
142-
st.success('Summarization completed.')
127+
if not st.session_state.processed:
128+
uploaded_files = st.file_uploader("Upload content",
129+
type=["docx", "pdf", "txt", *SUPPORTED_IMG_FILE_TYPES, *SUPPORTED_AV_FILE_TYPES],
130+
accept_multiple_files=True)
131+
132+
if st.button("Dive in"):
133+
if not uploaded_files:
134+
st.warning("Please upload a file first.")
135+
st.stop()
136+
137+
# Check if uploaded files have changed
138+
uploaded_files_changed = False
139+
if len(uploaded_files) != len(st.session_state.uploaded_files):
140+
uploaded_files_changed = True
141+
else:
142+
for new_file, old_file in zip(uploaded_files, st.session_state.uploaded_files):
143+
if new_file != old_file:
144+
uploaded_files_changed = True
145+
break
146+
147+
if uploaded_files_changed:
148+
st.session_state.uploaded_files = uploaded_files
149+
150+
# Load the text from the uploaded PDF files
151+
st.session_state.documents = load_files_with_spinner(st.session_state.uploaded_files)
152+
st.session_state.embeddings = embed_documents_with_spinner(st.session_state.documents)
153+
st.session_state.keywords = extract_keywords_with_spinner(st.session_state.embeddings)
154+
st.session_state.summaries = generate_summary_with_spinner(st.session_state.documents)
155+
156+
# Display the keywords and summaries
157+
for doc_name, doc_keywords in st.session_state.keywords.items():
158+
st.header(doc_name)
159+
st.caption("Keywords: " + ", ".join(doc_keywords))
160+
st.write(st.session_state.summaries[doc_name])
161+
st.divider()
162+
163+
# Display the plan
164+
st.header("Study Plan")
165+
plan = generate_plan_with_spinner()
166+
for data in plan.get("plan", []):
167+
for doc_name, doc_data in data.items():
168+
st.subheader(doc_name)
169+
st.write(f"**Background Concepts**")
170+
for concept in doc_data.get("Background Concepts", []):
171+
st.write(f"- {concept}")
172+
st.write(f"**Key Concepts**")
173+
for concept in doc_data.get("Key Concepts", []):
174+
st.write(f"- {concept}")
175+
st.write(f"**Further Reading**")
176+
for concept in doc_data.get("Further Reading", []):
177+
st.write(f"- {concept}")
178+
st.write("")
179+
180+
st.session_state.processed = True
181+
st.success('Summarization completed.')
182+
183+
if st.session_state.processed:
184+
st.header("Loaded Files")
185+
for file in st.session_state.uploaded_files:
186+
st.write(file.name)
187+
188+
st.header("Chat")
189+
# Iterate through the bot and user message and print them alternatively
190+
message_i = 0
191+
while True:
192+
if len(st.session_state.bot_messages) > message_i:
193+
message(st.session_state.bot_messages[message_i])
194+
else:
195+
break
196+
if len(st.session_state.user_messages) > message_i:
197+
message(st.session_state.user_messages[message_i], is_user=True)
198+
else:
199+
break
200+
message_i += 1
201+
202+
user_input = st.text_input("What do you want to learn more about?", key="user_message")
203+
if st.button("Send"):
204+
generate_response(user_input)
205+
st.experimental_rerun()

0 commit comments

Comments
 (0)