# Page residue captured with this file -- Spaces status: "Runtime error".
import logging
from pathlib import Path

import gradio as gr
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
# Load the pre-trained Pix2Struct model and processor once at import time so
# every request handled by the app reuses the same instances.
# NOTE(review): confirm that this checkpoint id is actually published on the
# Hugging Face Hub; if it is not, `from_pretrained` fails at startup.
model_name = "google/pix2struct-mathqa-base"
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
processor = Pix2StructProcessor.from_pretrained(model_name)
def solve_math_problem(image) -> str:
    """Run the Pix2Struct model on an uploaded image and return its answer.

    Args:
        image: PIL image handed over by the Gradio image input, or ``None``
            when the form is submitted without an upload.

    Returns:
        On success, a two-line string: ``Problem: <prompt>`` followed by
        ``Solution: <decoded model output>``. On any failure, a single line
        starting with ``Error processing image: ``. Exceptions are logged and
        never propagated, so the UI always receives a string.
    """
    # Guard the "submit without upload" case explicitly instead of letting it
    # surface as an opaque AttributeError on `None.convert`.
    if image is None:
        return "Error processing image: no image was provided"

    prompt = "Solve the following math problem:"
    try:
        # Preprocess the image; the model expects RGB input.
        rgb_image = image.convert("RGB")
        inputs = processor(
            images=[rgb_image],  # The processor takes a batch, so wrap in a list.
            text=prompt,
            return_tensors="pt",
            max_patches=2048,  # Upper bound on image patches fed to the encoder.
        )

        # Generate the solution with beam search. No `temperature` is passed:
        # it only applies when sampling (`do_sample=True`), which is not used.
        predictions = model.generate(
            **inputs,
            max_new_tokens=200,
            early_stopping=True,
            num_beams=4,
        )

        # Decode the best beam back into text.
        solution = processor.decode(
            predictions[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    except Exception as e:  # UI boundary: report the failure instead of crashing.
        logging.getLogger(__name__).exception("Failed to process uploaded image")
        return f"Error processing image: {str(e)}"

    # The processor output has no `input_ids` entry to decode, so the prompt
    # text is echoed directly.
    return f"Problem: {prompt}\nSolution: {solution}"
# Gradio interface with explicit image handling.
# Example images are optional: only files that are present in the working
# directory are offered, so a file that was never uploaded is skipped rather
# than handed to Gradio.
_EXAMPLE_FILES = ("example_addition.png", "example_algebra.jpg")
_available_examples = [[name] for name in _EXAMPLE_FILES if Path(name).is_file()]

demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(
        type="pil",
        label="Upload Handwritten Math Problem",
        image_mode="RGB",  # Force RGB format
        # Gradio 4 replaced the single `source="upload"` keyword with a
        # `sources` list; restrict input to file uploads only.
        sources=["upload"],
    ),
    outputs=gr.Textbox(label="Solution", show_copy_button=True),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem (algebra, arithmetic, etc.) and get the solution",
    examples=_available_examples or None,
    theme="soft",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()