import os
import tempfile

import gradio as gr
import numpy as np
from PIL import Image
from hezar.models import Model
from hezar.utils import load_image, draw_boxes
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load models on CPU (Hugging Face Spaces default)
craft_model = Model.load("hezarai/CRAFT", device="cpu")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")


def recognize_handwritten_text(image):
    tmp_path = None
    try:
        # Ensure the input is an RGB PIL image regardless of the upload format
        # (JPEG cannot store an alpha channel, so RGBA uploads must be converted)
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))
        image = image.convert("RGB")

        # Save the uploaded image to a temporary JPEG so hezar can load it from disk
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            image.save(tmp_file.name, format="JPEG")
            tmp_path = tmp_file.name

        # Load the image with the hezar utility, falling back to NumPy if needed
        processed_image = load_image(tmp_path)
        if not isinstance(processed_image, np.ndarray):
            processed_image = np.array(Image.open(tmp_path))

        # Detect text regions with CRAFT
        outputs = craft_model.predict(processed_image)
        if not outputs or "boxes" not in outputs[0]:
            return Image.fromarray(processed_image), "No text detected"
        boxes = outputs[0]["boxes"]
        print(f"Debug: Boxes structure = {boxes}")  # Log the exact structure

        pil_image = Image.fromarray(processed_image)
        texts = []

        # CRAFT may return each box as a polygon of corner points (e.g. four
        # [x, y] vertices), as two opposite corners, or as [x, y, width, height];
        # normalize every case to (x_min, y_min, x_max, y_max).
        for box in boxes:
            box = np.asarray(box)
            if box.ndim == 2 and box.shape[1] == 2:
                # Polygon or corner-pair form: [[x1, y1], [x2, y2], ...]
                x_min, y_min = box.min(axis=0)
                x_max, y_max = box.max(axis=0)
            elif box.ndim == 1 and box.shape[0] == 4:
                # Flat form: [x, y, width, height]
                x, y, width, height = box
                x_min, y_min = x, y
                x_max, y_max = x + width, y + height
            else:
                print(f"Debug: Skipping invalid box {box}")  # Log invalid boxes
                continue

            # Crop the detected region and recognize it with TrOCR
            crop = pil_image.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
            pixel_values = processor(images=crop, return_tensors="pt").pixel_values
            generated_ids = trocr_model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            texts.append(text)

        # Draw the detected boxes on the image
        result_image = draw_boxes(processed_image, boxes)
        result_pil = Image.fromarray(result_image)

        # Join the recognized texts
        text_data = " ".join(texts) if texts else "No text recognized"
        return result_pil, f"Recognized text: {text_data}"
    except Exception as e:
        return Image.fromarray(np.array(image)), f"Error: {e}"
    finally:
        # Clean up the temporary file
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)


# Create the Gradio interface
interface = gr.Interface(
    fn=recognize_handwritten_text,
    inputs=gr.Image(type="pil", label="Upload any image format"),
    outputs=[
        gr.Image(type="pil", label="Detected Text Image"),
        gr.Text(label="Recognized Text"),
    ],
    title="Handwritten Text Detection and Recognition",
    description="Upload an image in any format (JPEG, PNG, BMP, etc.) to detect and recognize handwritten text.",
)

# Launch the app
interface.launch()