inam09 commited on
Commit
6ca6013
·
verified ·
1 Parent(s): ca405e4

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +355 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,357 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
 
 
 
 
4
  import streamlit as st
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import pytesseract
4
+ import pdfplumber
5
+ from PIL import Image
6
+ from transformers import pipeline
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain_community.vectorstores import FAISS
9
+ from langchain.docstore.document import Document
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from gtts import gTTS
12
  import streamlit as st
13
+ import re
14
+ import nltk
15
+ import random
16
+ import torch # Import torch
17
 
18
+ # Download NLTK resources if not already downloaded
19
+ nltk.download('averaged_perceptron_tagger', quiet=True)
20
+ nltk.download('punkt', quiet=True)
21
+ nltk.download('punkt_tab', quiet=True)
22
+
23
+ # Load models
24
+ summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=torch.device("cpu"))
25
+ qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad", device=torch.device("cpu"))
26
+ embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
27
+
28
+ vector_dbs = {} # Dictionary to store multiple vector databases, keyed by document title
29
+ extracted_texts = {} # Dictionary to store extracted text, keyed by document title
30
+ current_doc_title = None
31
+
32
+ # ---------------------------------------------
33
+ # Extract text from PDF or Image
34
+ # ---------------------------------------------
35
+ def extract_text(uploaded_file):
36
+ global current_doc_title
37
+ current_doc_title = uploaded_file.name
38
+ suffix = uploaded_file.name.lower()
39
+ with tempfile.NamedTemporaryFile(delete=False) as tmp:
40
+ tmp.write(uploaded_file.read())
41
+ path = tmp.name
42
+
43
+ text = ""
44
+ if suffix.endswith(".pdf"):
45
+ with pdfplumber.open(path) as pdf:
46
+ for page in pdf.pages:
47
+ page_text = page.extract_text()
48
+ if page_text:
49
+ text += page_text + "\n"
50
+ else:
51
+ try:
52
+ text = pytesseract.image_to_string(Image.open(path))
53
+ except Exception as e:
54
+ st.error(f"Error during OCR for {uploaded_file.name}: {e}")
55
+ text = ""
56
+
57
+ os.remove(path)
58
+ return text.strip()
59
+
60
+ # ---------------------------------------------
61
+ # Store Embeddings in FAISS
62
+ # ---------------------------------------------
63
+ def store_vector(text):
64
+ global vector_dbs, current_doc_title
65
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
66
+ docs = text_splitter.create_documents([text])
67
+ for doc in docs:
68
+ doc.metadata = {"title": current_doc_title} # Add document title as metadata
69
+
70
+ if current_doc_title in vector_dbs:
71
+ vector_dbs[current_doc_title].add_documents(docs) # Append to existing DB
72
+ else:
73
+ vector_dbs[current_doc_title] = FAISS.from_documents(docs, embedding_model)
74
+
75
+ # ---------------------------------------------
76
+ # Summarize Text
77
+ # ---------------------------------------------
78
+ def summarize(text):
79
+ if len(text.split()) < 100:
80
+ return "Text too short to summarize."
81
+ chunks = [text[i : i + 1024] for i in range(0, len(text), 1024)]
82
+ summaries = []
83
+ for chunk in chunks:
84
+ max_len = int(len(chunk.split()) * 0.6)
85
+ max_len = max(30, min(max_len, 150))
86
+ try:
87
+ summary = summarizer(chunk, max_length=max_len, min_length=20)[0]["summary_text"]
88
+ summaries.append(summary)
89
+ except Exception as e:
90
+ st.error(f"Error during summarization: {e}")
91
+ return "An error occurred during summarization."
92
+ return " ".join(summaries)
93
+
94
+ # ---------------------------------------------
95
+ # Question Answering
96
+ # ---------------------------------------------
97
+ def topic_search(question, doc_title=None):
98
+ global vector_dbs
99
+ if not vector_dbs:
100
+ st.warning("Please upload and process a file first in the 'Upload & Extract' tab.")
101
+ return ""
102
+
103
+ try:
104
+ if doc_title and doc_title in vector_dbs:
105
+ retriever = vector_dbs[doc_title].as_retriever(search_kwargs={"k": 3})
106
+ else:
107
+ combined_docs = []
108
+ for db in vector_dbs.values():
109
+ combined_docs.extend(db.get_relevant_documents(question))
110
+ if not combined_docs:
111
+ return "No relevant information found across uploaded documents."
112
+ temp_db = FAISS.from_documents(combined_docs, embedding_model)
113
+ retriever = temp_db.as_retriever(search_kwargs={"k": 3})
114
+
115
+ relevant_docs = retriever.get_relevant_documents(question)
116
+ context = "\n\n".join([doc.page_content for doc in relevant_docs])
117
+ answer = qa_pipeline(question=question, context=context)["answer"]
118
+ return answer.strip()
119
+ except Exception as e:
120
+ st.error(f"Error during question answering: {e}")
121
+ return "An error occurred while trying to answer the question."
122
+
123
+ # ---------------------------------------------
124
+ # Flashcard Generation
125
+ # ---------------------------------------------
126
+ def generate_flashcards(text):
127
+ flashcards = []
128
+ seen_terms = set()
129
+ sentences = nltk.sent_tokenize(text)
130
+
131
+ for i, sent in enumerate(sentences):
132
+ words = nltk.word_tokenize(sent)
133
+ tagged_words = nltk.pos_tag(words)
134
+
135
+ potential_terms = [word for word, tag in tagged_words if tag.startswith('NN') or tag.startswith('NP')]
136
+
137
+ for term in potential_terms:
138
+ if term in seen_terms:
139
+ continue
140
+
141
+ defining_patterns = [r"\b" + re.escape(term) + r"\b\s+is\s+(?:a|an|the)\s+(.+?)(?:\.|,|\n|$)",
142
+ r"\b" + re.escape(term) + r"\b\s+refers\s+to\s+(.+?)(?:\.|,|\n|$)",
143
+ r"\b" + re.escape(term) + r"\b\s+means\s+(.+?)(?:\.|,|\n|$)",
144
+ r"\b" + re.escape(term) + r"\b,\s+defined\s+as\s+(.+?)(?:\.|,|\n|$)",
145
+ r"\b" + re.escape(term) + r"\b:\s+(.+?)(?:\.|,|\n|$)"]
146
+
147
+ potential_definitions = []
148
+ for pattern in defining_patterns:
149
+ match = re.search(pattern, sent, re.IGNORECASE)
150
+ if match and len(match.groups()) >= 1:
151
+ potential_definitions.append(match.group(1).strip())
152
+
153
+ for definition in potential_definitions:
154
+ if 2 <= len(definition.split()) <= 30:
155
+ flashcards.append({"term": term, "definition": definition})
156
+ seen_terms.add(term)
157
+ break
158
+
159
+ if term not in seen_terms and i > 0:
160
+ prev_sent = sentences[i-1]
161
+ defining_patterns_prev = [r"The\s+\b" + re.escape(term) + r"\b\s+is\s+(.+?)(?:\.|,|\n|$)",
162
+ r"This\s+\b" + re.escape(term) + r"\b\s+refers\s+to\s+(.+?)(?:\.|,|\n|$)",
163
+ r"It\s+means\s+the\s+\b" + re.escape(term) + r"\b\s+(.+?)(?:\.|,|\n|$)"]
164
+ for pattern in defining_patterns_prev:
165
+ match = re.search(pattern, prev_sent, re.IGNORECASE)
166
+ if match and term in sent and len(match.groups()) >= 1:
167
+ definition = match.group(1).strip()
168
+ if 2 <= len(definition.split()) <= 30:
169
+ flashcards.append({"term": term, "definition": definition})
170
+ seen_terms.add(term)
171
+ break
172
+
173
+ return flashcards
174
+
175
+ # ---------------------------------------------
176
+ # Text to Speech
177
+ # ---------------------------------------------
178
+ def read_aloud(text):
179
+ try:
180
+ tts = gTTS(text)
181
+ audio_path = os.path.join(tempfile.gettempdir(), "summary.mp3")
182
+ tts.save(audio_path)
183
+ return audio_path
184
+ except Exception as e:
185
+ st.error(f"Error during text-to-speech: {e}")
186
+ return None
187
+
188
+ # ---------------------------------------------
189
+ # Quiz Generation and Handling
190
+ # ---------------------------------------------
191
+ def generate_quiz_questions(text, num_questions=5):
192
+ flashcards = generate_flashcards(text) # Reuse flashcard logic for potential terms/definitions
193
+ if not flashcards:
194
+ return []
195
+
196
+ questions = []
197
+ used_indices = set()
198
+ num_available = len(flashcards)
199
+
200
+ while len(questions) < num_questions and len(used_indices) < num_available:
201
+ index = random.randint(0, num_available - 1)
202
+ if index in used_indices:
203
+ continue
204
+ used_indices.add(index)
205
+ card = flashcards[index]
206
+ correct_answer = card['term']
207
+ definition = card['definition']
208
+
209
+ # Generate incorrect answers (very basic for now)
210
+ incorrect_options = random.sample([c['term'] for i, c in enumerate(flashcards) if i != index], 3)
211
+ options = [correct_answer] + incorrect_options
212
+ random.shuffle(options)
213
+
214
+ questions.append({
215
+ "question": f"What is the term for: {definition}",
216
+ "options": options,
217
+ "correct_answer": correct_answer,
218
+ "user_answer": None # To store user's choice
219
+ })
220
+
221
+ return questions
222
+
223
+ def display_quiz(questions):
224
+ st.session_state.quiz_questions = questions
225
+ st.session_state.user_answers = {}
226
+ st.session_state.quiz_submitted = False
227
+ for i, q in enumerate(st.session_state.quiz_questions):
228
+ st.subheader(f"Question {i + 1}:")
229
+ st.write(q["question"])
230
+ st.session_state.user_answers[i] = st.radio(f"Answer for Question {i + 1}", q["options"])
231
+ st.button("Submit Quiz", on_click=submit_quiz)
232
+
233
+ def submit_quiz():
234
+ st.session_state.quiz_submitted = True
235
+
236
+ def grade_quiz():
237
+ if st.session_state.quiz_submitted:
238
+ score = 0
239
+ for i, q in enumerate(st.session_state.quiz_questions):
240
+ user_answer = st.session_state.user_answers.get(i)
241
+ if user_answer == q["correct_answer"]:
242
+ score += 1
243
+ st.success(f"Question {i + 1}: Correct!")
244
+ else:
245
+ st.error(f"Question {i + 1}: Incorrect. Correct answer was: {q['correct_answer']}")
246
+ st.write(f"## Your Score: {score} / {len(st.session_state.quiz_questions)}")
247
+
248
+ # ---------------------------------------------
249
+ # Streamlit Interface with Tabs
250
+ # ---------------------------------------------
251
+ st.title("📘 AI Study Assistant")
252
+
253
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(["Upload & Extract", "Summarize", "Question Answering", "Interactive Learning", "Quiz"])
254
+
255
+ with tab1:
256
+ st.header("📤 Upload and Extract Text")
257
+ uploaded_files = st.file_uploader("Upload multiple PDF or Image files", type=["pdf", "png", "jpg", "jpeg"], accept_multiple_files=True)
258
+ if uploaded_files:
259
+ for file in uploaded_files:
260
+ with st.spinner(f"Extracting text from {file.name}..."):
261
+ extracted_text = extract_text(file)
262
+ if extracted_text:
263
+ extracted_texts[file.name] = extracted_text
264
+ st.success(f"Text Extracted Successfully from {file.name}!")
265
+ with st.expander(f"View Extracted Text from {file.name}"):
266
+ st.text_area("Extracted Text", value=extracted_text[:3000], height=400)
267
+ store_vector(extracted_text)
268
+ else:
269
+ st.warning(f"Could not extract any text from {file.name}.")
270
+
271
+ with tab2:
272
+ st.header("📝 Summarize Text")
273
+ doc_titles = list(vector_dbs.keys())
274
+ if doc_titles:
275
+ selected_doc_title_summary = st.selectbox("Summarize document:", doc_titles)
276
+ if st.button("Generate Summary"):
277
+ if selected_doc_title_summary in extracted_texts:
278
+ with st.spinner(f"Summarizing {selected_doc_title_summary}..."):
279
+ summary = summarize(extracted_texts[selected_doc_title_summary])
280
+ st.subheader("Summary")
281
+ st.write(summary)
282
+ audio_path = read_aloud(summary)
283
+ if audio_path:
284
+ st.audio(audio_path)
285
+ else:
286
+ st.warning(f"Original text for {selected_doc_title_summary} not found. Please re-upload.")
287
+ else:
288
+ st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")
289
+
290
+ with tab3:
291
+ st.header("❓ Question Answering")
292
+ doc_titles = list(vector_dbs.keys())
293
+ if doc_titles:
294
+ doc_title = st.selectbox("Search within document:", ["All Documents"] + doc_titles)
295
+ question = st.text_input("Ask a question about the content:")
296
+ if question:
297
+ with st.spinner("Searching for answer..."):
298
+ if doc_title == "All Documents":
299
+ answer = topic_search(question)
300
+ else:
301
+ answer = topic_search(question, doc_title=doc_title)
302
+ if answer:
303
+ st.subheader("Answer:")
304
+ st.write(answer)
305
+ else:
306
+ st.warning("Could not find an answer in the selected document(s).")
307
+ else:
308
+ st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")
309
+
310
+ with tab4:
311
+ st.header("🧠 Interactive Learning: Flashcards")
312
+ doc_titles = list(extracted_texts.keys())
313
+ if doc_titles:
314
+ selected_doc_title_flashcard = st.selectbox("Generate flashcards from document:", doc_titles)
315
+ if st.button("Generate Flashcards"):
316
+ if selected_doc_title_flashcard in extracted_texts:
317
+ with st.spinner(f"Generating flashcards from {selected_doc_title_flashcard}..."):
318
+ flashcards = generate_flashcards(extracted_texts[selected_doc_title_flashcard])
319
+ if flashcards:
320
+ st.subheader("Flashcards")
321
+ for i, card in enumerate(flashcards):
322
+ with st.expander(f"Card {i+1}"):
323
+ st.markdown(f"*Term:* {card['term']}")
324
+ st.markdown(f"*Definition:* {card['definition']}")
325
+ else:
326
+ st.info("No flashcards could be generated from this document using the current method.")
327
+ else:
328
+ st.warning(f"Original text for {selected_doc_title_flashcard} not found. Please re-upload.")
329
+ else:
330
+ st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")
331
+
332
+ with tab5:
333
+ st.header("📝 Quiz Yourself!")
334
+ doc_titles = list(extracted_texts.keys())
335
+ if doc_titles:
336
+ selected_doc_title_quiz = st.selectbox("Generate quiz from document:", doc_titles)
337
+ if selected_doc_title_quiz in extracted_texts:
338
+ text_for_quiz =extracted_texts[selected_doc_title_quiz]
339
+ if "quiz_questions" not in st.session_state:
340
+ st.session_state.quiz_questions = generate_quiz_questions(text_for_quiz)
341
+
342
+ if st.session_state.quiz_questions:
343
+ display_quiz(st.session_state.quiz_questions)
344
+ if st.session_state.quiz_submitted:
345
+ grade_quiz()
346
+
347
+ if st.button("Refresh Questions"):
348
+ st.session_state.quiz_questions = generate_quiz_questions(text_for_quiz)
349
+ st.session_state.quiz_submitted = False
350
+ st.session_state.user_answers = {}
351
+ st.rerun() # Force a re-render to show new questions
352
+ else:
353
+ st.info("Could not generate quiz questions from the current document.")
354
+ else:
355
+ st.warning(f"Original text for {selected_doc_title_quiz} not found. Please re-upload.")
356
+ else:
357
+ st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")