import os
import tempfile
import pytesseract
import pdfplumber
from PIL import Image
from transformers import pipeline
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from gtts import gTTS
import streamlit as st
import re
import nltk
import random
import torch

# Download NLTK resources to a writable location (e.g. /tmp on Hugging Face Spaces)
nltk_data_dir = "/tmp/nltk_data"
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.download('punkt', download_dir=nltk_data_dir, quiet=True)  # needed by sent_tokenize / word_tokenize
nltk.download('averaged_perceptron_tagger', download_dir=nltk_data_dir, quiet=True)  # needed by pos_tag
nltk.data.path.append(nltk_data_dir)

# Load models once; st.cache_resource keeps them warm across Streamlit reruns
@st.cache_resource
def load_models():
    summ = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=torch.device("cpu"))
    qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad", device=torch.device("cpu"))
    emb = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    return summ, qa, emb
summarizer, qa_pipeline, embedding_model = load_models()

vector_dbs = {}  # Dictionary to store multiple vector databases, keyed by document title
extracted_texts = {} # Dictionary to store extracted text, keyed by document title
current_doc_title = None
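# NOTE: Streamlit re-executes this script on every interaction, so these
# module-level stores are rebuilt each rerun from the files st.file_uploader
# keeps in memory; st.session_state would persist them across reruns instead.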

# ---------------------------------------------
# Extract text from PDF or Image
# ---------------------------------------------
def extract_text(uploaded_file):
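    """Write the upload to a temp file, then extract text with pdfplumber (PDFs) or Tesseract OCR (images)."""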
    global current_doc_title
    current_doc_title = uploaded_file.name
    suffix = uploaded_file.name.lower()
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(uploaded_file.read())
        path = tmp.name

    text = ""
    if suffix.endswith(".pdf"):
        with pdfplumber.open(path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    text += page_text + "\n"
    else:
        try:
            text = pytesseract.image_to_string(Image.open(path))
        except Exception as e:
            st.error(f"Error during OCR for {uploaded_file.name}: {e}")
            text = ""

    os.remove(path)
    return text.strip()

# ---------------------------------------------
# Store Embeddings in FAISS
# ---------------------------------------------
def store_vector(text):
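    """Split text into overlapping chunks, tag each with the document title, and index them in a per-document FAISS store."""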
    global vector_dbs, current_doc_title
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    docs = text_splitter.create_documents([text])
    for doc in docs:
        doc.metadata = {"title": current_doc_title}  # Add document title as metadata

    if current_doc_title in vector_dbs:
        vector_dbs[current_doc_title].add_documents(docs)  # Append to existing DB
    else:
        vector_dbs[current_doc_title] = FAISS.from_documents(docs, embedding_model)

# ---------------------------------------------
# Summarize Text
# ---------------------------------------------
def summarize(text):
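    """Summarize chunk-by-chunk (~1024 characters per chunk) with DistilBART and join the partial summaries."""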
    if len(text.split()) < 100:
        return "Text too short to summarize."
    chunks = [text[i : i + 1024] for i in range(0, len(text), 1024)]
    summaries = []
    for chunk in chunks:
        max_len = int(len(chunk.split()) * 0.6)
        max_len = max(30, min(max_len, 150))
        try:
            summary = summarizer(chunk, max_length=max_len, min_length=20)[0]["summary_text"]
            summaries.append(summary)
        except Exception as e:
            st.error(f"Error during summarization: {e}")
            return "An error occurred during summarization."
    return " ".join(summaries)

# ---------------------------------------------
# Question Answering
# ---------------------------------------------
def topic_search(question, doc_title=None):
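    """Retrieve the top-k relevant chunks (from one document or all of them) and run extractive QA over the combined context."""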
    global vector_dbs
    if not vector_dbs:
        st.warning("Please upload and process a file first in the 'Upload & Extract' tab.")
        return ""

    try:
        if doc_title and doc_title in vector_dbs:
            retriever = vector_dbs[doc_title].as_retriever(search_kwargs={"k": 3})
        else:
            combined_docs = []
            for db in vector_dbs.values():
                # FAISS vector stores expose similarity_search; get_relevant_documents
                # only exists on retrievers, so calling it here raised AttributeError
                combined_docs.extend(db.similarity_search(question, k=3))
            if not combined_docs:
                return "No relevant information found across uploaded documents."
            temp_db = FAISS.from_documents(combined_docs, embedding_model)
            retriever = temp_db.as_retriever(search_kwargs={"k": 3})

        relevant_docs = retriever.get_relevant_documents(question)
        context = "\n\n".join([doc.page_content for doc in relevant_docs])
        answer = qa_pipeline(question=question, context=context)["answer"]
        return answer.strip()
    except Exception as e:
        st.error(f"Error during question answering: {e}")
        return "An error occurred while trying to answer the question."

# ---------------------------------------------
# Flashcard Generation
# ---------------------------------------------
def generate_flashcards(text):
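    """Mine term/definition pairs by POS-tagging sentences for nouns and matching 'X is/refers to/means ...' patterns."""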
    flashcards = []
    seen_terms = set()
    sentences = nltk.sent_tokenize(text)

    for i, sent in enumerate(sentences):
        words = nltk.word_tokenize(sent)
        tagged_words = nltk.pos_tag(words)

        # Penn Treebank noun tags all begin with 'NN' (NN, NNS, NNP, NNPS)
        potential_terms = [word for word, tag in tagged_words if tag.startswith('NN')]

        for term in potential_terms:
            if term in seen_terms:
                continue

            defining_patterns = [r"\b" + re.escape(term) + r"\b\s+is\s+(?:a|an|the)\s+(.+?)(?:\.|,|\n|$)",
                                 r"\b" + re.escape(term) + r"\b\s+refers\s+to\s+(.+?)(?:\.|,|\n|$)",
                                 r"\b" + re.escape(term) + r"\b\s+means\s+(.+?)(?:\.|,|\n|$)",
                                 r"\b" + re.escape(term) + r"\b,\s+defined\s+as\s+(.+?)(?:\.|,|\n|$)",
                                 r"\b" + re.escape(term) + r"\b:\s+(.+?)(?:\.|,|\n|$)"]

            potential_definitions = []
            for pattern in defining_patterns:
                match = re.search(pattern, sent, re.IGNORECASE)
                if match and len(match.groups()) >= 1:
                    potential_definitions.append(match.group(1).strip())

            for definition in potential_definitions:
                if 2 <= len(definition.split()) <= 30:
                    flashcards.append({"term": term, "definition": definition})
                    seen_terms.add(term)
                    break

            if term not in seen_terms and i > 0:
                prev_sent = sentences[i-1]
                defining_patterns_prev = [r"The\s+\b" + re.escape(term) + r"\b\s+is\s+(.+?)(?:\.|,|\n|$)",
                                          r"This\s+\b" + re.escape(term) + r"\b\s+refers\s+to\s+(.+?)(?:\.|,|\n|$)",
                                          r"It\s+means\s+the\s+\b" + re.escape(term) + r"\b\s+(.+?)(?:\.|,|\n|$)"]
                for pattern in defining_patterns_prev:
                    match = re.search(pattern, prev_sent, re.IGNORECASE)
                    if match and term in sent and len(match.groups()) >= 1:
                        definition = match.group(1).strip()
                        if 2 <= len(definition.split()) <= 30:
                            flashcards.append({"term": term, "definition": definition})
                            seen_terms.add(term)
                            break

    return flashcards

# ---------------------------------------------
# Text to Speech
# ---------------------------------------------
def read_aloud(text):
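    """Convert text to an MP3 with gTTS and return the path to the saved file."""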
    try:
        tts = gTTS(text)
        audio_path = os.path.join(tempfile.gettempdir(), "summary.mp3")
        tts.save(audio_path)
        return audio_path
    except Exception as e:
        st.error(f"Error during text-to-speech: {e}")
        return None

# ---------------------------------------------
# Quiz Generation and Handling
# ---------------------------------------------
def generate_quiz_questions(text, num_questions=5):
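    """Build multiple-choice questions from flashcards: the definition is the prompt, its term is the answer, and other terms are distractors."""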
    flashcards = generate_flashcards(text) # Reuse flashcard logic for potential terms/definitions
    if not flashcards:
        return []

    questions = []
    used_indices = set()
    num_available = len(flashcards)

    while len(questions) < num_questions and len(used_indices) < num_available:
        index = random.randint(0, num_available - 1)
        if index in used_indices:
            continue
        used_indices.add(index)
        card = flashcards[index]
        correct_answer = card['term']
        definition = card['definition']

        # Draw three distractor terms; skip this card when there aren't enough
        # other flashcards for a four-option question (random.sample would raise)
        distractors = [c['term'] for i, c in enumerate(flashcards) if i != index]
        if len(distractors) < 3:
            continue
        incorrect_options = random.sample(distractors, 3)
        options = [correct_answer] + incorrect_options
        random.shuffle(options)

        questions.append({
            "question": f"What is the term for: {definition}",
            "options": options,
            "correct_answer": correct_answer,
            "user_answer": None # To store user's choice
        })

    return questions

def display_quiz(questions):
    st.session_state.quiz_questions = questions
    # Initialise only when absent: the Submit button's on_click callback sets
    # quiz_submitted before this rerun, and overwriting it here would discard the click
    st.session_state.setdefault("user_answers", {})
    st.session_state.setdefault("quiz_submitted", False)
    for i, q in enumerate(st.session_state.quiz_questions):
        st.subheader(f"Question {i + 1}:")
        st.write(q["question"])
        st.session_state.user_answers[i] = st.radio(f"Answer for Question {i + 1}", q["options"])
    st.button("Submit Quiz", on_click=submit_quiz)

def submit_quiz():
    st.session_state.quiz_submitted = True

def grade_quiz():
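    """Score the submitted quiz and display per-question feedback."""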
    if st.session_state.quiz_submitted:
        score = 0
        for i, q in enumerate(st.session_state.quiz_questions):
            user_answer = st.session_state.user_answers.get(i)
            if user_answer == q["correct_answer"]:
                score += 1
                st.success(f"Question {i + 1}: Correct!")
            else:
                st.error(f"Question {i + 1}: Incorrect. Correct answer was: {q['correct_answer']}")
        st.write(f"## Your Score: {score} / {len(st.session_state.quiz_questions)}")

# ---------------------------------------------
# Streamlit Interface with Tabs
# ---------------------------------------------
st.title("📘 AI Study Assistant")

tab1, tab2, tab3, tab4, tab5 = st.tabs(["Upload & Extract", "Summarize", "Question Answering", "Interactive Learning", "Quiz"])

with tab1:
    st.header("📤 Upload and Extract Text")
    uploaded_files = st.file_uploader("Upload multiple PDF or Image files", type=["pdf", "png", "jpg", "jpeg"], accept_multiple_files=True)
    if uploaded_files:
        for file in uploaded_files:
            with st.spinner(f"Extracting text from {file.name}..."):
                extracted_text = extract_text(file)
            if extracted_text:
                extracted_texts[file.name] = extracted_text
                st.success(f"Text Extracted Successfully from {file.name}!")
                with st.expander(f"View Extracted Text from {file.name}"):
                    st.text_area("Extracted Text", value=extracted_text[:3000], height=400, key=f"extracted_{file.name}")
                store_vector(extracted_text)
            else:
                st.warning(f"Could not extract any text from {file.name}.")

with tab2:
    st.header("📝 Summarize Text")
    doc_titles = list(vector_dbs.keys())
    if doc_titles:
        selected_doc_title_summary = st.selectbox("Summarize document:", doc_titles)
        if st.button("Generate Summary"):
            if selected_doc_title_summary in extracted_texts:
                with st.spinner(f"Summarizing {selected_doc_title_summary}..."):
                    summary = summarize(extracted_texts[selected_doc_title_summary])
                st.subheader("Summary")
                st.write(summary)
                audio_path = read_aloud(summary)
                if audio_path:
                    st.audio(audio_path)
            else:
                st.warning(f"Original text for {selected_doc_title_summary} not found. Please re-upload.")
    else:
        st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")

with tab3:
    st.header("❓ Question Answering")
    doc_titles = list(vector_dbs.keys())
    if doc_titles:
        doc_title = st.selectbox("Search within document:", ["All Documents"] + doc_titles)
        question = st.text_input("Ask a question about the content:")
        if question:
            with st.spinner("Searching for answer..."):
                if doc_title == "All Documents":
                    answer = topic_search(question)
                else:
                    answer = topic_search(question, doc_title=doc_title)
            if answer:
                st.subheader("Answer:")
                st.write(answer)
            else:
                st.warning("Could not find an answer in the selected document(s).")
    else:
        st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")

with tab4:
    st.header("🧠 Interactive Learning: Flashcards")
    doc_titles = list(extracted_texts.keys())
    if doc_titles:
        selected_doc_title_flashcard = st.selectbox("Generate flashcards from document:", doc_titles)
        if st.button("Generate Flashcards"):
            if selected_doc_title_flashcard in extracted_texts:
                with st.spinner(f"Generating flashcards from {selected_doc_title_flashcard}..."):
                    flashcards = generate_flashcards(extracted_texts[selected_doc_title_flashcard])
                if flashcards:
                    st.subheader("Flashcards")
                    for i, card in enumerate(flashcards):
                        with st.expander(f"Card {i+1}"):
                            st.markdown(f"*Term:* {card['term']}")
                            st.markdown(f"*Definition:* {card['definition']}")
                else:
                    st.info("No flashcards could be generated from this document using the current method.")
            else:
                st.warning(f"Original text for {selected_doc_title_flashcard} not found. Please re-upload.")
    else:
        st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")

with tab5:
    st.header("📝 Quiz Yourself!")
    doc_titles = list(extracted_texts.keys())
    if doc_titles:
        selected_doc_title_quiz = st.selectbox("Generate quiz from document:", doc_titles)
        if selected_doc_title_quiz in extracted_texts:
            text_for_quiz = extracted_texts[selected_doc_title_quiz]
            # Questions persist in session state across reruns; after switching
            # documents, use "Refresh Questions" below to regenerate them
            if "quiz_questions" not in st.session_state:
                st.session_state.quiz_questions = generate_quiz_questions(text_for_quiz)

            if st.session_state.quiz_questions:
                display_quiz(st.session_state.quiz_questions)
                if st.session_state.quiz_submitted:
                    grade_quiz()

                if st.button("Refresh Questions"):
                    st.session_state.quiz_questions = generate_quiz_questions(text_for_quiz)
                    st.session_state.quiz_submitted = False
                    st.session_state.user_answers = {}
                    st.rerun() # Force a re-render to show new questions
            else:
                st.info("Could not generate quiz questions from the current document.")
        else:
            st.warning(f"Original text for {selected_doc_title_quiz} not found. Please re-upload.")
    else:
        st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")