Spaces:
Running
Running
import gradio as gr | |
from pathlib import Path | |
import asyncio | |
import google.generativeai as genai | |
import os | |
import logging | |
from dotenv import load_dotenv | |
from typing import Optional, Tuple | |
from flashcard import FlashcardSet | |
from chat_agent import ( | |
chat_agent, | |
ChatDeps, | |
ChatResponse | |
) | |
# Load environment variables (e.g. from a local .env file) into os.environ and
# configure the Gemini client. Fail fast with an actionable message when the
# API key is missing, instead of a bare KeyError at import time.
load_dotenv()
_gemini_api_key = os.environ.get("GEMINI_API_KEY")
if not _gemini_api_key:
    raise RuntimeError(
        "GEMINI_API_KEY is not set. Add it to the environment or to a .env file."
    )
genai.configure(api_key=_gemini_api_key)
async def process_message(message: dict, history: list, current_flashcards: Optional[FlashcardSet]) -> Tuple[str, list, Optional[FlashcardSet]]:
    """Process uploaded files and chat messages from the multimodal input.

    A submission may carry uploaded files, text, or both. When files are
    present, only the first one is handled: a PDF is read and passed to
    ``chat_agent`` (with any accompanying text used as extra instructions);
    any other file is rejected with a message. Without files, non-empty text
    is sent to ``chat_agent`` as a chat message.

    Args:
        message: Value of the multimodal textbox; read through its ``"text"``
            and ``"files"`` keys, where each file entry is used as a path.
        history: Chat history as a list of ``[user, bot]`` pairs. New turns
            are appended to it in place. ``None`` is treated as empty.
        current_flashcards: Flashcards held in session state, or ``None``.

    Returns:
        ``("", history, flashcards)``. The empty string clears the input box.
        It is followed by the updated history and the possibly replaced
        flashcards.
    """
    message = message or {}
    # The chatbot value can be None (clear_chat resets it to None).
    if history is None:
        history = []
    # "text" may be missing or explicitly None.
    user_text = (message.get("text") or "").strip()

    deps = ChatDeps(
        message=user_text,
        current_flashcards=current_flashcards,
    )

    files = message.get("files")
    if files:
        # Only the first file is processed; additional files are ignored.
        return await _handle_upload(files[0], user_text, deps, history, current_flashcards)

    if user_text:
        return await _handle_text(user_text, deps, history, current_flashcards)

    history.append(["", "Please upload a PDF file or send a message."])
    return "", history, current_flashcards


async def _handle_upload(file_path, user_text, deps, history, current_flashcards):
    """Send one uploaded PDF to the chat agent; reject other file types."""
    file_name = Path(file_path).name
    # Case-insensitive so that e.g. "Notes.PDF" is accepted as well.
    if not str(file_path).lower().endswith(".pdf"):
        history.append([f"Uploaded: {file_name}", "Please upload a PDF file."])
        return "", history, current_flashcards

    try:
        with open(file_path, "rb") as pdf_file:
            deps.pdf_data = pdf_file.read()
        # Text sent alongside the upload acts as generation instructions.
        deps.system_prompt = user_text if user_text else None
        # Let chat agent handle the PDF upload
        result = await chat_agent.run("Process this PDF upload", deps=deps)
        if result.data.should_generate_flashcards:
            current_flashcards = result.data.flashcards
        history.append([
            f"Uploaded: {file_name}"
            + (f"\nWith instructions: {user_text}" if user_text else ""),
            result.data.response,
        ])
    except Exception as e:
        # UI boundary: report the failure in the chat rather than crashing the
        # handler, and log the full traceback for debugging.
        error_msg = f"Error processing PDF: {str(e)}"
        logging.exception(error_msg)
        history.append([f"Uploaded: {file_name}", error_msg])
    return "", history, current_flashcards


async def _handle_text(user_text, deps, history, current_flashcards):
    """Send a plain chat message to the agent and record the reply."""
    try:
        result = await chat_agent.run(user_text, deps=deps)
        # Update flashcards if the agent reports it modified them.
        if result.data.should_modify_flashcards:
            current_flashcards = result.data.flashcards
        history.append([user_text, result.data.response])
    except Exception as e:
        error_msg = f"Error processing request: {str(e)}"
        logging.exception(error_msg)
        history.append([user_text, error_msg])
    return "", history, current_flashcards
async def clear_chat():
    """Produce reset values for the input box, chat log and flashcard state.

    The three ``None`` values are consumed positionally by the event wiring,
    one per output component.
    """
    blank = None
    return (blank,) * 3
# Create Gradio interface
with gr.Blocks(title="PDF Flashcard Generator") as demo:
    # Emoji below were mojibake in the original source (UTF-8 bytes decoded
    # with the wrong codec); restored to proper Unicode characters.
    gr.Markdown("""
    # 📚 PDF Flashcard Generator
    Upload a PDF document and get AI-generated flashcards to help you study!
    You can provide custom instructions along with your PDF upload to guide the flashcard generation.
    Powered by Google's Gemini AI
    """)

    chatbot = gr.Chatbot(
        label="Flashcard Generation Chat",
        bubble_full_width=False,
        show_copy_button=True,
        height=600,
    )

    # Session state for flashcards
    current_flashcards = gr.State(value=None)

    with gr.Row():
        chat_input = gr.MultimodalTextbox(
            label="Upload PDF or type a message",
            placeholder="Drop a PDF file here. You can also add instructions for how the flashcards should be generated...",
            # Extension and MIME type only; a bare "pdf" is neither and is ignored.
            file_types=[".pdf", "application/pdf"],
            show_label=False,
            sources=["upload"],
            scale=20,
            min_width=100,
        )
        clear_btn = gr.Button("🗑️", variant="secondary", scale=1, min_width=50)

    # Both handlers return (input box value, chat history, flashcard state).
    chat_input.submit(
        fn=process_message,
        inputs=[chat_input, chatbot, current_flashcards],
        outputs=[chat_input, chatbot, current_flashcards],
    )

    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chat_input, chatbot, current_flashcards],
    )
if __name__ == "__main__":
    # Let INFO-level and higher records from the root logger be emitted.
    logging.basicConfig(level=logging.INFO)
    # Listen on all interfaces at port 7860, without a public share link.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_kwargs)