import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import gradio as gr
from PIL import Image

# Use a valid model identifier. Here we use "google/matcha-base".
model_name = "google/matcha-base"

# Load the pre-trained Pix2Struct model and processor
model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
processor = Pix2StructProcessor.from_pretrained(model_name)

# Move the model to GPU if available and set it to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


def solve_math_problem(image):
    # Preprocess the image and include a prompt
    inputs = processor(images=image, text="Solve the math problem:", return_tensors="pt")
    # Move all input tensors to the same device as the model
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Generate the solution with beam search inside a no_grad context
    with torch.no_grad():
        predictions = model.generate(
            **inputs,
            max_new_tokens=150,   # Increase this if longer answers are needed
            num_beams=5,          # Beam search for more stable outputs
            early_stopping=True,  # Stop once enough complete beam candidates are found
        )

    # Decode the generated tokens to a string, skipping special tokens
    solution = processor.decode(predictions[0], skip_special_tokens=True)
    return solution


# Set up the Gradio interface
demo = gr.Interface(
    fn=solve_math_problem,
    inputs=gr.Image(type="pil", label="Upload Handwritten Math Problem"),
    outputs=gr.Textbox(label="Solution"),
    title="Handwritten Math Problem Solver",
    description="Upload an image of a handwritten math problem and the model will attempt to solve it.",
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()
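

# ---------------------------------------------------------------------------
# Optional: quick sanity check without the Gradio UI. This is a minimal sketch
# that assumes an image file (e.g. "sample_problem.png") exists locally; the
# filename and this helper are placeholders, not part of the original app.
# Call it instead of demo.launch() when debugging the model directly.
# ---------------------------------------------------------------------------
def _quick_local_test(image_path: str = "sample_problem.png") -> None:
    """Run a single inference on a local image file and print the result."""
    test_image = Image.open(image_path).convert("RGB")
    print(solve_math_problem(test_image))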