import os
import tempfile
import pytesseract
import pdfplumber
from PIL import Image
from transformers import pipeline
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from gtts import gTTS
import streamlit as st
import re
import nltk
import random
import torch  # used to pin the transformers pipelines to the CPU
# Download NLTK resources (sentence/word tokenizer + POS tagger) to a writable directory
nltk_data_dir = "/tmp/nltk_data"
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.download('punkt', download_dir=nltk_data_dir, quiet=True)
nltk.download('averaged_perceptron_tagger', download_dir=nltk_data_dir, quiet=True)
nltk.data.path.append(nltk_data_dir)
# Load models
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=torch.device("cpu"))
qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad", device=torch.device("cpu"))
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vector_dbs = {} # Dictionary to store multiple vector databases, keyed by document title
extracted_texts = {} # Dictionary to store extracted text, keyed by document title
current_doc_title = None
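# NOTE: Streamlit re-executes this script on every interaction, so these module-level
# dicts are rebuilt each rerun; the upload tab re-extracts and re-indexes the files
# still held by st.file_uploader, which is what keeps them populated across reruns.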
# ---------------------------------------------
# Extract text from PDF or Image
# ---------------------------------------------
def extract_text(uploaded_file):
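    """Extract raw text from an uploaded PDF (pdfplumber) or image (Tesseract OCR),
    recording the file name as the current document title."""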
global current_doc_title
current_doc_title = uploaded_file.name
    filename = uploaded_file.name.lower()
with tempfile.NamedTemporaryFile(delete=False) as tmp:
tmp.write(uploaded_file.read())
path = tmp.name
text = ""
    if filename.endswith(".pdf"):
with pdfplumber.open(path) as pdf:
for page in pdf.pages:
page_text = page.extract_text()
if page_text:
text += page_text + "\n"
else:
try:
text = pytesseract.image_to_string(Image.open(path))
except Exception as e:
st.error(f"Error during OCR for {uploaded_file.name}: {e}")
text = ""
os.remove(path)
return text.strip()
# ---------------------------------------------
# Store Embeddings in FAISS
# ---------------------------------------------
def store_vector(text):
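    """Split text into overlapping chunks, tag each chunk with its document title,
    and index the chunks in a per-document FAISS store."""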
global vector_dbs, current_doc_title
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
docs = text_splitter.create_documents([text])
for doc in docs:
doc.metadata = {"title": current_doc_title} # Add document title as metadata
if current_doc_title in vector_dbs:
vector_dbs[current_doc_title].add_documents(docs) # Append to existing DB
else:
vector_dbs[current_doc_title] = FAISS.from_documents(docs, embedding_model)
# ---------------------------------------------
# Summarize Text
# ---------------------------------------------
def summarize(text):
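    """Summarize text chunk-by-chunk with DistilBART and return the concatenated
    chunk summaries. Texts under 100 words are skipped."""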
if len(text.split()) < 100:
return "Text too short to summarize."
chunks = [text[i : i + 1024] for i in range(0, len(text), 1024)]
summaries = []
for chunk in chunks:
max_len = int(len(chunk.split()) * 0.6)
max_len = max(30, min(max_len, 150))
try:
summary = summarizer(chunk, max_length=max_len, min_length=20)[0]["summary_text"]
summaries.append(summary)
except Exception as e:
st.error(f"Error during summarization: {e}")
return "An error occurred during summarization."
return " ".join(summaries)
# ---------------------------------------------
# Question Answering
# ---------------------------------------------
def topic_search(question, doc_title=None):
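    """Answer a question with retrieval + extractive QA: fetch the most relevant
    chunks (from one document, or from all of them merged into a temporary index),
    then run the DistilBERT QA pipeline over the combined context."""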
global vector_dbs
if not vector_dbs:
st.warning("Please upload and process a file first in the 'Upload & Extract' tab.")
return ""
try:
if doc_title and doc_title in vector_dbs:
retriever = vector_dbs[doc_title].as_retriever(search_kwargs={"k": 3})
else:
            # Vectorstores expose similarity_search, not get_relevant_documents
            combined_docs = []
            for db in vector_dbs.values():
                combined_docs.extend(db.similarity_search(question, k=3))
if not combined_docs:
return "No relevant information found across uploaded documents."
temp_db = FAISS.from_documents(combined_docs, embedding_model)
retriever = temp_db.as_retriever(search_kwargs={"k": 3})
relevant_docs = retriever.get_relevant_documents(question)
context = "\n\n".join([doc.page_content for doc in relevant_docs])
answer = qa_pipeline(question=question, context=context)["answer"]
return answer.strip()
except Exception as e:
st.error(f"Error during question answering: {e}")
return "An error occurred while trying to answer the question."
# ---------------------------------------------
# Flashcard Generation
# ---------------------------------------------
def generate_flashcards(text):
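    """Heuristically mine term/definition pairs: POS-tag each sentence for noun
    terms, then match defining patterns such as 'X is a ...' or 'X refers to ...'
    in the current or the preceding sentence."""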
flashcards = []
seen_terms = set()
sentences = nltk.sent_tokenize(text)
for i, sent in enumerate(sentences):
words = nltk.word_tokenize(sent)
tagged_words = nltk.pos_tag(words)
        # NN* covers common and proper nouns in the Penn Treebank tagset
        potential_terms = [word for word, tag in tagged_words if tag.startswith('NN')]
for term in potential_terms:
if term in seen_terms:
continue
defining_patterns = [r"\b" + re.escape(term) + r"\b\s+is\s+(?:a|an|the)\s+(.+?)(?:\.|,|\n|$)",
r"\b" + re.escape(term) + r"\b\s+refers\s+to\s+(.+?)(?:\.|,|\n|$)",
r"\b" + re.escape(term) + r"\b\s+means\s+(.+?)(?:\.|,|\n|$)",
r"\b" + re.escape(term) + r"\b,\s+defined\s+as\s+(.+?)(?:\.|,|\n|$)",
r"\b" + re.escape(term) + r"\b:\s+(.+?)(?:\.|,|\n|$)"]
potential_definitions = []
for pattern in defining_patterns:
match = re.search(pattern, sent, re.IGNORECASE)
if match and len(match.groups()) >= 1:
potential_definitions.append(match.group(1).strip())
for definition in potential_definitions:
if 2 <= len(definition.split()) <= 30:
flashcards.append({"term": term, "definition": definition})
seen_terms.add(term)
break
if term not in seen_terms and i > 0:
prev_sent = sentences[i-1]
defining_patterns_prev = [r"The\s+\b" + re.escape(term) + r"\b\s+is\s+(.+?)(?:\.|,|\n|$)",
r"This\s+\b" + re.escape(term) + r"\b\s+refers\s+to\s+(.+?)(?:\.|,|\n|$)",
r"It\s+means\s+the\s+\b" + re.escape(term) + r"\b\s+(.+?)(?:\.|,|\n|$)"]
for pattern in defining_patterns_prev:
match = re.search(pattern, prev_sent, re.IGNORECASE)
if match and term in sent and len(match.groups()) >= 1:
definition = match.group(1).strip()
if 2 <= len(definition.split()) <= 30:
flashcards.append({"term": term, "definition": definition})
seen_terms.add(term)
break
return flashcards
# ---------------------------------------------
# Text to Speech
# ---------------------------------------------
def read_aloud(text):
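    """Convert text to speech with gTTS and return the path to the saved MP3,
    or None if synthesis fails."""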
try:
tts = gTTS(text)
audio_path = os.path.join(tempfile.gettempdir(), "summary.mp3")
tts.save(audio_path)
return audio_path
except Exception as e:
st.error(f"Error during text-to-speech: {e}")
return None
# ---------------------------------------------
# Quiz Generation and Handling
# ---------------------------------------------
def generate_quiz_questions(text, num_questions=5):
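    """Build multiple-choice questions from mined flashcards: the definition is the
    prompt, the term is the correct answer, and other cards' terms are distractors."""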
flashcards = generate_flashcards(text) # Reuse flashcard logic for potential terms/definitions
if not flashcards:
return []
questions = []
used_indices = set()
num_available = len(flashcards)
while len(questions) < num_questions and len(used_indices) < num_available:
index = random.randint(0, num_available - 1)
if index in used_indices:
continue
used_indices.add(index)
card = flashcards[index]
correct_answer = card['term']
definition = card['definition']
        # Build distractor options from other cards' terms (guard against having fewer than 4 cards)
        other_terms = [c['term'] for i, c in enumerate(flashcards) if i != index]
        incorrect_options = random.sample(other_terms, min(3, len(other_terms)))
        options = [correct_answer] + incorrect_options
random.shuffle(options)
questions.append({
"question": f"What is the term for: {definition}",
"options": options,
"correct_answer": correct_answer,
"user_answer": None # To store user's choice
})
return questions
def display_quiz(questions):
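    """Render the quiz questions as radio groups plus a submit button, keeping
    answers and the submitted flag in st.session_state."""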
    st.session_state.quiz_questions = questions
    st.session_state.user_answers = {}
    # Only initialize the submitted flag once; resetting it on every rerun would
    # clobber the value set by the submit button's on_click callback.
    if "quiz_submitted" not in st.session_state:
        st.session_state.quiz_submitted = False
for i, q in enumerate(st.session_state.quiz_questions):
st.subheader(f"Question {i + 1}:")
st.write(q["question"])
        st.session_state.user_answers[i] = st.radio(f"Answer for Question {i + 1}", q["options"], key=f"quiz_answer_{i}")
st.button("Submit Quiz", on_click=submit_quiz)
def submit_quiz():
st.session_state.quiz_submitted = True
def grade_quiz():
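    """Score the submitted quiz, show per-question feedback, and print the total."""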
if st.session_state.quiz_submitted:
score = 0
for i, q in enumerate(st.session_state.quiz_questions):
user_answer = st.session_state.user_answers.get(i)
if user_answer == q["correct_answer"]:
score += 1
st.success(f"Question {i + 1}: Correct!")
else:
st.error(f"Question {i + 1}: Incorrect. Correct answer was: {q['correct_answer']}")
st.write(f"## Your Score: {score} / {len(st.session_state.quiz_questions)}")
# ---------------------------------------------
# Streamlit Interface with Tabs
# ---------------------------------------------
st.title("📘 AI Study Assistant")
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Upload & Extract", "Summarize", "Question Answering", "Interactive Learning", "Quiz"])
with tab1:
st.header("📤 Upload and Extract Text")
uploaded_files = st.file_uploader("Upload multiple PDF or Image files", type=["pdf", "png", "jpg", "jpeg"], accept_multiple_files=True)
if uploaded_files:
for file in uploaded_files:
with st.spinner(f"Extracting text from {file.name}..."):
extracted_text = extract_text(file)
if extracted_text:
extracted_texts[file.name] = extracted_text
st.success(f"Text Extracted Successfully from {file.name}!")
with st.expander(f"View Extracted Text from {file.name}"):
                st.text_area("Extracted Text", value=extracted_text[:3000], height=400, key=f"extracted_{file.name}")
store_vector(extracted_text)
else:
st.warning(f"Could not extract any text from {file.name}.")
with tab2:
st.header("📝 Summarize Text")
doc_titles = list(vector_dbs.keys())
if doc_titles:
selected_doc_title_summary = st.selectbox("Summarize document:", doc_titles)
if st.button("Generate Summary"):
if selected_doc_title_summary in extracted_texts:
with st.spinner(f"Summarizing {selected_doc_title_summary}..."):
summary = summarize(extracted_texts[selected_doc_title_summary])
st.subheader("Summary")
st.write(summary)
audio_path = read_aloud(summary)
if audio_path:
st.audio(audio_path)
else:
st.warning(f"Original text for {selected_doc_title_summary} not found. Please re-upload.")
else:
st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")
with tab3:
st.header("❓ Question Answering")
doc_titles = list(vector_dbs.keys())
if doc_titles:
doc_title = st.selectbox("Search within document:", ["All Documents"] + doc_titles)
question = st.text_input("Ask a question about the content:")
if question:
with st.spinner("Searching for answer..."):
if doc_title == "All Documents":
answer = topic_search(question)
else:
answer = topic_search(question, doc_title=doc_title)
if answer:
st.subheader("Answer:")
st.write(answer)
else:
st.warning("Could not find an answer in the selected document(s).")
else:
st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")
with tab4:
st.header("🧠 Interactive Learning: Flashcards")
doc_titles = list(extracted_texts.keys())
if doc_titles:
selected_doc_title_flashcard = st.selectbox("Generate flashcards from document:", doc_titles)
if st.button("Generate Flashcards"):
if selected_doc_title_flashcard in extracted_texts:
with st.spinner(f"Generating flashcards from {selected_doc_title_flashcard}..."):
flashcards = generate_flashcards(extracted_texts[selected_doc_title_flashcard])
if flashcards:
st.subheader("Flashcards")
for i, card in enumerate(flashcards):
with st.expander(f"Card {i+1}"):
st.markdown(f"*Term:* {card['term']}")
st.markdown(f"*Definition:* {card['definition']}")
else:
st.info("No flashcards could be generated from this document using the current method.")
else:
st.warning(f"Original text for {selected_doc_title_flashcard} not found. Please re-upload.")
else:
st.info("Please upload and extract a file in the 'Upload & Extract' tab first.")
with tab5:
st.header("📝 Quiz Yourself!")
doc_titles = list(extracted_texts.keys())
if doc_titles:
selected_doc_title_quiz = st.selectbox("Generate quiz from document:", doc_titles)
if selected_doc_title_quiz in extracted_texts:
            text_for_quiz = extracted_texts[selected_doc_title_quiz]
if "quiz_questions" not in st.session_state:
st.session_state.quiz_questions = generate_quiz_questions(text_for_quiz)
if st.session_state.quiz_questions:
display_quiz(st.session_state.quiz_questions)
if st.session_state.quiz_submitted:
grade_quiz()
if st.button("Refresh Questions"):
st.session_state.quiz_questions = generate_quiz_questions(text_for_quiz)
st.session_state.quiz_submitted = False
st.session_state.user_answers = {}
st.rerun() # Force a re-render to show new questions
else:
st.info("Could not generate quiz questions from the current document.")
else:
st.warning(f"Original text for {selected_doc_title_quiz} not found. Please re-upload.")
else:
st.info("Please upload and extract a file in the 'Upload & Extract' tab first.") |