|
1 | 1 | """Streamlit App""" |
2 | 2 |
|
| 3 | +import openai |
3 | 4 | import streamlit as st |
4 | 5 | from langchain.schema import Document |
| 6 | +from streamlit_chat import message |
5 | 7 |
|
6 | 8 | from rabbithole import summarize_document |
7 | 9 | from rabbithole.embedding import embed_document |
|
# Seed the session-state slots used by the rest of the script. Only keys that
# are still missing are written, so values already stored are kept as they are.
for state_var, make_default in (
    ("uploaded_files", dict),
    ("documents", dict),
    ("embeddings", dict),
    ("keywords", dict),
    ("summaries", dict),
    ("processed", lambda: False),
    ("bot_messages", lambda: ["Hello, I am here to help you learn more efficiently"]),
    ("user_messages", list),
):
    if state_var not in st.session_state:
        # A factory is called per key so every slot gets its own fresh object.
        st.session_state[state_var] = make_default()
| 27 | + |
| 28 | + |
def generate_response(prompt):
    """Ask the chat model for a reply to ``prompt`` and store the exchange.

    The prompt is appended to ``st.session_state['user_messages']``, the whole
    conversation is sent to the ``gpt-4`` chat-completion endpoint, and the
    reply text is appended to ``st.session_state['bot_messages']``.

    Args:
        prompt: The text the user just submitted.

    Returns:
        None. The result is recorded in session state only.
    """
    bot_messages = st.session_state['bot_messages']
    user_messages = st.session_state['user_messages']
    user_messages.append(prompt)

    # Interleave the history as assistant, user, assistant, user, ... starting
    # with the assistant's greeting. The bound must be the number of stored
    # user turns: 'user_message' (singular) is the text-input widget's key and
    # holds the raw input string, so using its length truncated the history.
    messages = []
    for msg_idx, bot_message in enumerate(bot_messages):
        messages.append({"role": "assistant", "content": bot_message})
        if msg_idx >= len(user_messages):
            break
        messages.append({"role": "user", "content": user_messages[msg_idx]})
    print(messages)

    completion = openai.ChatCompletion.create(
        model="gpt-4",
        messages=messages,
    )
    response = completion.choices[0].message.content
    bot_messages.append(response)
    print(response)
17 | 54 |
|
18 | 55 |
|
19 | 56 | def load_files_with_spinner(files: list) -> dict[str, list[Document]]: |
@@ -87,56 +124,82 @@ def generate_plan_with_spinner() -> dict: |
87 | 124 |
|
88 | 125 | st.title("RabbitHole") |
89 | 126 |
|
90 | | -uploaded_files = st.file_uploader("Upload content", |
91 | | - type=["docx", "pdf", "txt", *SUPPORTED_IMG_FILE_TYPES, *SUPPORTED_AV_FILE_TYPES], |
92 | | - accept_multiple_files=True) |
93 | | - |
94 | | -if st.button("Dive in"): |
95 | | - if not uploaded_files: |
96 | | - st.warning("Please upload a file first.") |
97 | | - st.stop() |
98 | | - |
99 | | - # Check if uploaded files have changed |
100 | | - uploaded_files_changed = False |
101 | | - if len(uploaded_files) != len(st.session_state.uploaded_files): |
102 | | - uploaded_files_changed = True |
103 | | - else: |
104 | | - for new_file, old_file in zip(uploaded_files, st.session_state.uploaded_files): |
105 | | - if new_file != old_file: |
106 | | - uploaded_files_changed = True |
107 | | - break |
108 | | - |
109 | | - if uploaded_files_changed: |
110 | | - st.session_state.uploaded_files = uploaded_files |
111 | | - |
112 | | - # Load the text from the uploaded PDF files |
113 | | - st.session_state.documents = load_files_with_spinner(st.session_state.uploaded_files) |
114 | | - st.session_state.embeddings = embed_documents_with_spinner(st.session_state.documents) |
115 | | - st.session_state.keywords = extract_keywords_with_spinner(st.session_state.embeddings) |
116 | | - st.session_state.summaries = generate_summary_with_spinner(st.session_state.documents) |
117 | | - |
118 | | - # Display the keywords and summaries |
119 | | - for doc_name, doc_keywords in st.session_state.keywords.items(): |
120 | | - st.header(doc_name) |
121 | | - st.caption("Keywords: " + ", ".join(doc_keywords)) |
122 | | - st.write(st.session_state.summaries[doc_name]) |
123 | | - st.divider() |
124 | | - |
125 | | - # Display the plan |
126 | | - st.header("Study Plan") |
127 | | - plan = generate_plan_with_spinner() |
128 | | - for data in plan.get("plan", []): |
129 | | - for doc_name, doc_data in data.items(): |
130 | | - st.subheader(doc_name) |
131 | | - st.write(f"**Background Concepts**") |
132 | | - for concept in doc_data.get("Background Concepts", []): |
133 | | - st.write(f"- {concept}") |
134 | | - st.write(f"**Key Concepts**") |
135 | | - for concept in doc_data.get("Key Concepts", []): |
136 | | - st.write(f"- {concept}") |
137 | | - st.write(f"**Further Reading**") |
138 | | - for concept in doc_data.get("Further Reading", []): |
139 | | - st.write(f"- {concept}") |
140 | | - st.write("") |
141 | | - |
142 | | - st.success('Summarization completed.') |
# Upload-and-process stage: rendered only while no batch has been processed.
# NOTE(review): nesting was reconstructed from a flattened diff - confirm that
# the display steps sit outside the `uploaded_files_changed` guard.
if not st.session_state.processed:
    uploaded_files = st.file_uploader(
        "Upload content",
        type=["docx", "pdf", "txt", *SUPPORTED_IMG_FILE_TYPES, *SUPPORTED_AV_FILE_TYPES],
        accept_multiple_files=True,
    )

    if st.button("Dive in"):
        if not uploaded_files:
            st.warning("Please upload a file first.")
            st.stop()

        # The selection counts as changed when the number of files differs or
        # any file differs from the one stored at the same position.
        previous_files = st.session_state.uploaded_files
        uploaded_files_changed = len(uploaded_files) != len(previous_files) or any(
            new_file != old_file
            for new_file, old_file in zip(uploaded_files, previous_files)
        )

        if uploaded_files_changed:
            st.session_state.uploaded_files = uploaded_files

            # Rebuild every derived artefact from the new selection.
            st.session_state.documents = load_files_with_spinner(st.session_state.uploaded_files)
            st.session_state.embeddings = embed_documents_with_spinner(st.session_state.documents)
            st.session_state.keywords = extract_keywords_with_spinner(st.session_state.embeddings)
            st.session_state.summaries = generate_summary_with_spinner(st.session_state.documents)

        # Per-document keywords followed by the document summary.
        for doc_name, doc_keywords in st.session_state.keywords.items():
            st.header(doc_name)
            st.caption("Keywords: " + ", ".join(doc_keywords))
            st.write(st.session_state.summaries[doc_name])
            st.divider()

        # Study plan: each document lists the same three sections in order.
        st.header("Study Plan")
        plan = generate_plan_with_spinner()
        for data in plan.get("plan", []):
            for doc_name, doc_data in data.items():
                st.subheader(doc_name)
                for section in ("Background Concepts", "Key Concepts", "Further Reading"):
                    st.write(f"**{section}**")
                    for concept in doc_data.get(section, []):
                        st.write(f"- {concept}")
                st.write("")

        st.session_state.processed = True
        st.success('Summarization completed.')
| 182 | + |
# Chat stage: rendered once a batch of files has been processed.
if st.session_state.processed:
    st.header("Loaded Files")
    for file in st.session_state.uploaded_files:
        st.write(file.name)

    st.header("Chat")
    # Replay the conversation alternating bot and user bubbles, starting with
    # the bot and stopping as soon as either list runs out. Each bubble gets a
    # unique key: without one, two bubbles with identical text and role
    # collide on the same auto-generated widget ID and Streamlit raises a
    # duplicate-widget error.
    bot_messages = st.session_state.bot_messages
    user_messages = st.session_state.user_messages
    for message_i, bot_message in enumerate(bot_messages):
        message(bot_message, key=f"bot_{message_i}")
        if message_i >= len(user_messages):
            break
        message(user_messages[message_i], is_user=True, key=f"user_{message_i}")

    user_input = st.text_input("What do you want to learn more about?", key="user_message")
    if st.button("Send"):
        generate_response(user_input)
        # Re-run the script so the new exchange is drawn by the loop above.
        st.experimental_rerun()
0 commit comments