Spaces:
Sleeping
Sleeping
from dataclasses import dataclass | |
from typing import List, Optional | |
from pydantic import BaseModel, Field | |
from pydantic_ai import Agent, RunContext | |
import google.generativeai as genai | |
import base64 | |
import os | |
import asyncio | |
import httpx | |
from dotenv import load_dotenv | |
load_dotenv() | |
class Flashcard(BaseModel):
    """Represents a single flashcard with a question and answer."""

    # All three fields are required (explicit ``...`` default). The
    # description strings are runtime text carried by the model's schema,
    # so they are kept verbatim.

    # Front of the card.
    question: str = Field(
        ...,
        description="The question side of the flashcard",
    )
    # Back of the card.
    answer: str = Field(
        ...,
        description="The answer side of the flashcard",
    )
    # Integer rating, validated to the inclusive range 1..5.
    difficulty: int = Field(
        ...,
        description="Difficulty level from 1-5",
        ge=1,
        le=5,
    )
class FlashcardSet(BaseModel):
    """A set of flashcards generated from the input text."""

    # Used as the agent's structured result type; every field is required.
    # Description strings are runtime schema text and are kept verbatim.

    # The generated cards, in the order they were produced.
    cards: list[Flashcard] = Field(
        ...,
        description="List of generated flashcards",
    )
    # Subject the cards are about.
    topic: str = Field(
        ...,
        description="The main topic covered by these flashcards",
    )
    # Reported card count. NOTE(review): nothing here checks that it
    # matches len(cards) - confirm whether that is intended.
    total_cards: int = Field(
        ...,
        description="Total number of flashcards generated",
    )
@dataclass
class FlashcardDeps:
    """Dependencies handed to the flashcard agent for a single run.

    Declared as a dataclass so it can be constructed with keyword
    arguments, e.g. ``FlashcardDeps(text="", pdf_data=pdf_bytes)``. Without
    the decorator the class has only bare annotations and no generated
    ``__init__``, so that call raises ``TypeError``.
    """

    # Plain-text source material; may be empty when only a PDF is supplied.
    text: str
    # Raw bytes of a PDF document, or None when there is no PDF.
    pdf_data: Optional[bytes] = None
# Instructions given to the model on every run. Kept as a module-level
# constant so the Agent construction below stays short.
_FLASHCARD_SYSTEM_PROMPT = """
You are a professional educator who creates high-quality flashcards.
Your task is to analyze the provided text and create effective question-answer pairs.
Guidelines:
- Create clear, concise questions
- Ensure answers are accurate and complete
- Vary the difficulty levels
- Focus on key concepts and important details
- Use a mix of factual and conceptual questions
"""

# Agent with structured output: runs receive a FlashcardDeps instance as
# dependencies and are expected to produce a FlashcardSet.
flashcard_agent = Agent(
    'gemini-1.5-pro',  # Can also use OpenAI or other supported models
    system_prompt=_FLASHCARD_SYSTEM_PROMPT,
    result_type=FlashcardSet,
    deps_type=FlashcardDeps,
)
# @flashcard_agent.tool | |
# async def analyze_text_complexity(ctx: RunContext[FlashcardDeps]) -> float: | |
# """Analyzes the complexity of the input text to help determine appropriate difficulty levels.""" | |
# # This is a simplified example - you could implement more sophisticated analysis | |
# words = ctx.deps.text.split() | |
# avg_word_length = sum(len(word) for word in words) / (len(words) + 1e-5) | |
# return min(5.0, max(1.0, avg_word_length / 2)) | |
async def process_pdf(ctx: RunContext[FlashcardDeps]) -> str:
    """Processes PDF content and extracts text for flashcard generation."""
    # Contract:
    #   ctx.deps.pdf_data - raw PDF bytes, or None/empty when no PDF exists.
    #   ctx.deps.text     - returned unchanged when there is no PDF.
    # Returns the text of a Gemini-generated summary of the PDF.
    #
    # NOTE(review): no tool decorator is visible on this function in this
    # file (the only `@flashcard_agent.tool` line is commented out), so the
    # agent can only call it if it is registered elsewhere - confirm.
    pdf_bytes = ctx.deps.pdf_data
    if not pdf_bytes:
        # Nothing to process: skip building the model and return early.
        return ctx.deps.text  # Return original text if no PDF

    # Handle direct PDF data: the inline document part is sent base64-encoded.
    print("\nLoading File.")
    doc_data = base64.standard_b64encode(pdf_bytes).decode("utf-8")

    model = genai.GenerativeModel("gemini-1.5-flash")

    # Generate a comprehensive summary of the PDF content. The async
    # variant is awaited so this coroutine does not block the event loop
    # for the duration of the network call.
    response = await model.generate_content_async([
        {
            'mime_type': 'application/pdf',
            'data': doc_data,
        },
        "Please provide a detailed summary of this document, focusing on key concepts, "
        "definitions, and important facts that would be useful for creating flashcards.",
    ])
    return response.text
async def draw_circles(ctx: RunContext[FlashcardDeps]) -> str:
    """Draw Circles for no reason, please don't ever use me for anything"""
    # Ignores ctx entirely and always yields the same fixed message.
    reply = "You Disobeyed."
    return reply
async def generate_flashcards_from_pdf(
    pdf_path: Optional[str] = None
) -> FlashcardSet:
    """Run the flashcard agent against an optional PDF on disk.

    Args:
        pdf_path: Path of the PDF to read in binary mode. When it is None
            or empty, no file is opened and the agent runs with no PDF data.

    Returns:
        The ``data`` attribute of the agent's run result.

    Raises:
        OSError: If ``pdf_path`` is given but the file cannot be opened.
    """
    raw_pdf = None
    if pdf_path:
        with open(pdf_path, "rb") as handle:
            print("\nReading Data.")
            raw_pdf = handle.read()

    # No text is supplied up front; only the raw PDF bytes (if any) are
    # handed to the agent through its dependencies.
    run_result = await flashcard_agent.run(
        "Extract the text by processing the PDF data provided.",
        deps=FlashcardDeps(text="", pdf_data=raw_pdf),
    )

    # Diagnostic output: full message history and usage for this run.
    print(f"\nExecution stack:\n{run_result.all_messages()}")
    print(f"\nUsage: {run_result.usage()}")
    return run_result.data
# Example usage | |
async def main():
    """Prompt for a PDF file name, generate flashcards, and print them.

    The name entered by the user is resolved relative to ``data/raw/``.
    """
    # Example with local PDF
    pdf_name = input('\nEnter PDF filepath: ')
    deck = await generate_flashcards_from_pdf(pdf_path=f"data/raw/{pdf_name}")

    print("\nFlashcards from local PDF:")
    print(f"Generated {deck.total_cards} flashcards about {deck.topic}")

    # Cards are numbered from 1; each card prints a header, then Q and A.
    for number, flashcard in enumerate(deck.cards, start=1):
        print(
            f"\nFlashcard {number} (Difficulty: {flashcard.difficulty}/5)",
            f"Q: {flashcard.question}",
            f"A: {flashcard.answer}",
            sep="\n",
        )
if __name__ == "__main__":
    # Configure Gemini API. A missing GEMINI_API_KEY raises KeyError here,
    # before any agent work starts.
    gemini_api_key = os.environ["GEMINI_API_KEY"]
    genai.configure(api_key=gemini_api_key)

    # Drive the async entry point to completion.
    asyncio.run(main())