Skip to content

Commit 7c039b5

Browse files
authored
Merge pull request #5 from punitarani/chat
Add chat bot
2 parents 20ef688 + 721d715 commit 7c039b5

1 file changed

Lines changed: 111 additions & 29 deletions

File tree

app.py

Lines changed: 111 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Streamlit App"""
22

3+
import openai
34
import streamlit as st
45
from langchain.schema import Document
6+
from streamlit_chat import message
57

68
from rabbithole import summarize_document
79
from rabbithole.embedding import embed_document
@@ -10,11 +12,47 @@
1012
from rabbithole.mp3 import SUPPORTED_AV_FILE_TYPES
1113
from rabbithole.planner import generate_plan
1214

13-
# Global variables
14-
global_documents = {}
15-
global_embeddings = {}
16-
global_keywords = {}
17-
global_summaries = {}
15+
# Session variables
16+
for state_var in ["uploaded_files", "documents", "embeddings", "keywords", "summaries"]:
17+
if state_var not in st.session_state:
18+
st.session_state[state_var] = {}
19+
if "plan" not in st.session_state:
20+
st.session_state.plan = None
21+
if "processed" not in st.session_state:
22+
st.session_state.processed = False
23+
if "bot_messages" not in st.session_state:
24+
st.session_state.bot_messages = [
25+
"Hello, I am here to help you learn more efficiently"
26+
]
27+
if "user_messages" not in st.session_state:
28+
st.session_state.user_messages = []
29+
30+
31+
def generate_response(prompt):
    """Generate the chat bot's reply to ``prompt`` using the OpenAI chat API.

    The prompt is appended to ``st.session_state['user_messages']``, the
    conversation so far is replayed to the model as alternating
    assistant/user turns (the bot always speaks first), and the model's
    reply is appended to ``st.session_state['bot_messages']``.

    Args:
        prompt: The text the user just submitted.

    Returns:
        None. The reply is stored in session state rather than returned.
    """
    st.session_state['user_messages'].append(prompt)

    bot_messages = st.session_state['bot_messages']
    # Bug fix: the history loop used to test
    # len(st.session_state['user_message']) (singular). That key belongs to
    # the chat text-input widget, so the bound was the character count of the
    # input box, not the number of stored user turns, and short inputs could
    # drop the most recent user messages from the request.
    user_messages = st.session_state['user_messages']

    # Interleave bot and user turns, stopping as soon as either side runs out.
    messages = []
    for msg_idx, bot_message in enumerate(bot_messages):
        messages.append({"role": "assistant", "content": bot_message})
        if msg_idx >= len(user_messages):
            break
        messages.append({"role": "user", "content": user_messages[msg_idx]})
    print(messages)
    completion = openai.ChatCompletion.create(
        model="gpt-4",
        messages=messages
    )
    response = completion.choices[0].message.content
    st.session_state['bot_messages'].append(response)
    print(response)
1856

1957

2058
def load_files_with_spinner(files: list) -> dict[str, list[Document]]:
@@ -80,39 +118,65 @@ def generate_summary_with_spinner(documents: dict[str, list[Document]]) -> dict[
80118
def generate_plan_with_spinner() -> dict:
81119
"""Generate a logical plan to study the uploaded documents."""
82120
with st.spinner("Generating plan..."):
83-
plan = generate_plan(global_summaries, global_keywords)
121+
plan = generate_plan(st.session_state.summaries, st.session_state.keywords)
84122
return plan
85123

86124

87125
st.set_page_config(page_title="RabbitHole", page_icon="🐇", layout="wide")
88126

89127
st.title("RabbitHole")
90128

91-
uploaded_files = st.file_uploader("Upload content",
92-
type=["docx", "pdf", "txt", *SUPPORTED_IMG_FILE_TYPES, *SUPPORTED_AV_FILE_TYPES],
93-
accept_multiple_files=True)
94-
95-
if st.button("Dive in"):
96-
if not uploaded_files:
97-
st.warning("Please upload a file first.")
98-
st.stop()
99-
100-
# Load the text from the uploaded PDF files
101-
global_documents = load_files_with_spinner(uploaded_files)
102-
global_embeddings = embed_documents_with_spinner(global_documents)
103-
global_keywords = extract_keywords_with_spinner(global_embeddings)
104-
global_summaries = generate_summary_with_spinner(global_documents)
105-
106-
# Display the keywords and summaries
107-
for doc_name, doc_keywords in global_keywords.items():
108-
st.header(doc_name)
109-
st.caption("Keywords: " + ", ".join(doc_keywords))
110-
st.write(global_summaries[doc_name])
111-
st.divider()
129+
if not st.session_state.processed:
130+
uploaded_files = st.file_uploader("Upload content",
131+
type=["docx", "pdf", "txt", *SUPPORTED_IMG_FILE_TYPES, *SUPPORTED_AV_FILE_TYPES],
132+
accept_multiple_files=True)
133+
134+
if st.button("Dive in"):
135+
if not uploaded_files:
136+
st.warning("Please upload a file first.")
137+
st.stop()
138+
139+
# Check if uploaded files have changed
140+
uploaded_files_changed = False
141+
if len(uploaded_files) != len(st.session_state.uploaded_files):
142+
uploaded_files_changed = True
143+
else:
144+
for new_file, old_file in zip(uploaded_files, st.session_state.uploaded_files):
145+
if new_file != old_file:
146+
uploaded_files_changed = True
147+
break
148+
149+
if uploaded_files_changed:
150+
st.session_state.uploaded_files = uploaded_files
151+
152+
# Load the text from the uploaded files
153+
st.session_state.documents = load_files_with_spinner(st.session_state.uploaded_files)
154+
st.session_state.embeddings = embed_documents_with_spinner(st.session_state.documents)
155+
st.session_state.keywords = extract_keywords_with_spinner(st.session_state.embeddings)
156+
st.session_state.summaries = generate_summary_with_spinner(st.session_state.documents)
157+
158+
# Display the keywords and summaries
159+
for doc_name, doc_keywords in st.session_state.keywords.items():
160+
st.header(doc_name)
161+
st.caption("Keywords: " + ", ".join(doc_keywords))
162+
st.write(st.session_state.summaries[doc_name])
163+
st.divider()
164+
165+
st.session_state.processed = True
166+
st.success('Summarization completed.')
167+
168+
if st.session_state.processed:
169+
st.header("Loaded Files")
170+
for file in st.session_state.uploaded_files:
171+
st.write(file.name)
112172

113173
# Display the plan
114174
st.header("Study Plan")
115-
plan = generate_plan_with_spinner()
175+
if st.session_state.plan is None:
176+
plan = generate_plan_with_spinner()
177+
st.session_state.plan = plan
178+
else:
179+
plan = st.session_state.plan
116180
for data in plan.get("plan", []):
117181
for doc_name, doc_data in data.items():
118182
st.subheader(doc_name)
@@ -127,4 +191,22 @@ def generate_plan_with_spinner() -> dict:
127191
st.write(f"- {concept}")
128192
st.write("")
129193

130-
st.success('Summarization completed.')
194+
st.header("Chat")
195+
# Iterate through the bot and user messages and print them alternately
196+
message_i = 0
197+
while True:
198+
if len(st.session_state.bot_messages) > message_i:
199+
message(st.session_state.bot_messages[message_i])
200+
else:
201+
break
202+
if len(st.session_state.user_messages) > message_i:
203+
message(st.session_state.user_messages[message_i], is_user=True)
204+
else:
205+
break
206+
message_i += 1
207+
208+
user_input = st.text_input("What do you want to learn more about?", key="user_message")
209+
if st.button("Send"):
210+
with st.spinner("Generating response..."):
211+
generate_response(user_input)
212+
st.experimental_rerun()

0 commit comments

Comments
 (0)