flashcard-app / flashcard.py
adityachivu's picture
initial commit
599f736
raw
history blame
4.55 kB
from dataclasses import dataclass
from typing import List, Optional
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
import google.generativeai as genai
import base64
import os
import asyncio
import httpx
from dotenv import load_dotenv
load_dotenv()
class Flashcard(BaseModel):
    """Represents a single flashcard with a question and answer."""

    # NOTE: pydantic copies the class docstring and every Field description
    # into the model's JSON schema, so treat those strings as runtime text
    # rather than free-form documentation.

    # Front of the card.
    question: str = Field(
        description="The question side of the flashcard",
    )

    # Back of the card.
    answer: str = Field(
        description="The answer side of the flashcard",
    )

    # Integer rating; the ge/le bounds make pydantic reject values outside 1..5.
    difficulty: int = Field(
        description="Difficulty level from 1-5",
        ge=1,
        le=5,
    )
class FlashcardSet(BaseModel):
    """A set of flashcards generated from the input text."""

    # NOTE: pydantic copies the class docstring and every Field description
    # into the model's JSON schema, so treat those strings as runtime text
    # rather than free-form documentation.

    # The generated cards, in the order they were produced.
    cards: List[Flashcard] = Field(
        description="List of generated flashcards",
    )

    # Short label for the subject the cards cover.
    topic: str = Field(
        description="The main topic covered by these flashcards",
    )

    # Count reported alongside the cards.
    # NOTE(review): nothing in this model checks that it equals len(cards).
    total_cards: int = Field(
        description="Total number of flashcards generated",
    )
@dataclass
class FlashcardDeps:
    """Per-run inputs made available to the flashcard agent's tools."""

    # Plain-text source material (required).
    text: str

    # Raw bytes of a PDF document, or None when the run has no PDF.
    pdf_data: Optional[bytes] = None
# Instructions given to the model on every run. This is runtime text: any
# edit here changes what the model is told to do.
_FLASHCARD_SYSTEM_PROMPT = """
You are a professional educator who creates high-quality flashcards.
Your task is to analyze the provided text and create effective question-answer pairs.
Guidelines:
- Create clear, concise questions
- Ensure answers are accurate and complete
- Vary the difficulty levels
- Focus on key concepts and important details
- Use a mix of factual and conceptual questions
"""

# Create the agent with structured output: tools receive a FlashcardDeps
# instance and a run's result is a FlashcardSet.
flashcard_agent = Agent(
    'gemini-1.5-pro',  # Can also use OpenAI or other supported models
    system_prompt=_FLASHCARD_SYSTEM_PROMPT,
    result_type=FlashcardSet,
    deps_type=FlashcardDeps,
)
# @flashcard_agent.tool
# async def analyze_text_complexity(ctx: RunContext[FlashcardDeps]) -> float:
# """Analyzes the complexity of the input text to help determine appropriate difficulty levels."""
# # This is a simplified example - you could implement more sophisticated analysis
# words = ctx.deps.text.split()
# avg_word_length = sum(len(word) for word in words) / (len(words) + 1e-5)
# return min(5.0, max(1.0, avg_word_length / 2))
@flashcard_agent.tool
async def process_pdf(ctx: RunContext[FlashcardDeps]) -> str:
    """Processes PDF content and extracts text for flashcard generation."""
    # NOTE: the docstring above is deliberately left as-is -- pydantic_ai
    # passes a tool's docstring to the model as the tool description.
    #
    # Behaviour:
    #   * ctx.deps.pdf_data is None or empty -> return ctx.deps.text unchanged.
    #   * otherwise -> send the PDF to Gemini and return the summary text of
    #     its response.
    pdf_bytes = ctx.deps.pdf_data
    if not pdf_bytes:
        # No PDF for this run: hand back the original text. Checked first so
        # no model object is built when it would go unused.
        return ctx.deps.text

    # Handle direct PDF data: the inline-data part carries it as base64 text.
    print("\nLoading File.")
    doc_data = base64.standard_b64encode(pdf_bytes).decode("utf-8")

    model = genai.GenerativeModel("gemini-1.5-flash")

    # Generate a comprehensive summary of the PDF content. The async variant
    # is awaited so this coroutine does not block the event loop for the
    # duration of the network request (the synchronous generate_content did).
    response = await model.generate_content_async([
        {
            'mime_type': 'application/pdf',
            'data': doc_data
        },
        "Please provide a detailed summary of this document, focusing on key concepts, "
        "definitions, and important facts that would be useful for creating flashcards."
    ])
    return response.text
@flashcard_agent.tool
async def draw_circles(ctx: RunContext[FlashcardDeps]) -> str:
    """Draw Circles for no reason, please don't ever use me for anything"""
    # Does no work: ctx is ignored and the same fixed string is returned on
    # every call. The docstring is left untouched because pydantic_ai passes
    # a tool's docstring to the model as the tool description.
    fixed_reply = "You Disobeyed."
    return fixed_reply
async def generate_flashcards_from_pdf(
    pdf_path: Optional[str] = None
) -> FlashcardSet:
    """Generate flashcards from a PDF file.

    Reads the file at ``pdf_path`` (when one is given), runs
    ``flashcard_agent`` with those bytes as its dependencies, prints the
    run's message history and usage, and returns the structured result.

    Args:
        pdf_path: Path of the PDF to read. When omitted or empty, no file is
            read and the agent runs with ``pdf_data=None``.

    Returns:
        The ``FlashcardSet`` produced by the agent run.
    """
    if pdf_path:
        with open(pdf_path, "rb") as handle:
            print("\nReading Data.")
            raw_pdf = handle.read()
    else:
        raw_pdf = None

    # text is left empty: nothing in this function fills it in, and
    # process_pdf only falls back to it when there is no PDF data.
    run_result = await flashcard_agent.run(
        "Extract the text by processing the PDF data provided.",
        deps=FlashcardDeps(text="", pdf_data=raw_pdf),
    )

    # Diagnostic output: full message history and usage of the run.
    print(f"\nExecution stack:\n{run_result.all_messages()}")
    print(f"\nUsage: {run_result.usage()}")
    return run_result.data
# Example usage
async def main():
    """Ask for a PDF name, generate flashcards from it and print them.

    The name typed by the user is looked up under ``data/raw/``.
    """
    # Example with local PDF
    filepath = input('\nEnter PDF filepath: ')
    deck = await generate_flashcards_from_pdf(pdf_path=f"data/raw/{filepath}")

    print("\nFlashcards from local PDF:")
    print(f"Generated {deck.total_cards} flashcards about {deck.topic}")

    # Cards are numbered from 1 for display.
    for number, flashcard in enumerate(deck.cards, start=1):
        print(
            f"\nFlashcard {number} (Difficulty: {flashcard.difficulty}/5)",
            f"Q: {flashcard.question}",
            f"A: {flashcard.answer}",
            sep="\n",
        )
if __name__ == "__main__":
    # Configure Gemini API. The key must be present in the environment
    # (a missing GEMINI_API_KEY raises KeyError here).
    gemini_key = os.environ["GEMINI_API_KEY"]
    genai.configure(api_key=gemini_key)

    # Run the interactive example.
    asyncio.run(main())