# Hugging Face Spaces status: Running
from hezar.models import Model | |
from hezar.utils import load_image, draw_boxes | |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
# Both models are loaded once, when this module is imported, and run on the
# CPU (the Hugging Face Spaces default).
_TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"

# CRAFT text-region detector, loaded through hezar.
craft_model = Model.load("hezarai/CRAFT", device="cpu")

# TrOCR handwritten-text recognizer: image processor plus encoder-decoder model
# from the same checkpoint.
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
trocr_model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
def _box_to_crop_rect(box, width, height):
    """Convert one detected box into a clamped ``(left, top, right, bottom)`` crop rectangle.

    Two box layouts are accepted:

    * a polygon given as a sequence of ``(x, y)`` points. The axis-aligned
      bounding rectangle of *all* points is used, so rotated quadrilaterals
      still give a valid crop. This covers points whose order does not put the
      top-left corner first and the bottom-right corner third.
    * a flat sequence of four numbers. NOTE(review): this is assumed to be
      ``[x, y, width, height]``; confirm against hezar's CRAFT output format.

    Args:
        box: One entry of the detector's ``"boxes"`` list.
        width: Image width in pixels, used for clamping.
        height: Image height in pixels, used for clamping.

    Returns:
        An integer ``(left, top, right, bottom)`` tuple, or ``None`` in two
        cases: the layout is not recognised, or the region is empty after
        clamping to the image bounds.
    """
    points = np.asarray(box, dtype=float)
    if points.ndim == 2 and points.shape[0] > 0 and points.shape[1] >= 2:
        xs, ys = points[:, 0], points[:, 1]
        left, top, right, bottom = xs.min(), ys.min(), xs.max(), ys.max()
    elif points.shape == (4,):
        x, y, w, h = points
        left, top, right, bottom = x, y, x + w, y + h
    else:
        return None
    # Clamp to the image so the crop never extends past the picture, and round
    # outward so the whole detected region is kept.
    left = int(max(0, np.floor(left)))
    top = int(max(0, np.floor(top)))
    right = int(min(width, np.ceil(right)))
    bottom = int(min(height, np.ceil(bottom)))
    if right <= left or bottom <= top:
        return None
    return left, top, right, bottom


def _recognize_crops(crops, batch_size=16):
    """Transcribe a list of PIL crops with TrOCR, ``batch_size`` images per ``generate`` call.

    Returns one decoded string per crop, in input order.
    """
    texts = []
    for start in range(0, len(crops), batch_size):
        batch = crops[start:start + batch_size]
        pixel_values = processor(images=batch, return_tensors="pt").pixel_values
        generated_ids = trocr_model.generate(pixel_values)
        texts.extend(processor.batch_decode(generated_ids, skip_special_tokens=True))
    return texts


def recognize_handwritten_text(image):
    """Detect text regions with CRAFT and transcribe each region with TrOCR.

    Args:
        image: The PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(image, message)`` pair:

        * ``image`` is the RGB input with the detected boxes drawn on it. It is
          the plain RGB input when nothing is detected, and ``None`` when no
          image was given.
        * ``message`` is the recognized text or a status message.
    """
    if image is None:
        return None, "No image provided"

    # Uploads may be RGBA, palette or grayscale. Normalise to 3-channel RGB so
    # the detector and the TrOCR processor see the same pixel layout.
    rgb_image = image.convert("RGB")
    # np.array (not asarray) gives a writable copy, in case draw_boxes draws in
    # place. The array goes straight to CRAFT: hezar's load_image presumably
    # expects a file path, so it is not used on in-memory data.
    image_np = np.array(rgb_image)

    # Detect text regions with CRAFT.
    outputs = craft_model.predict(image_np)
    if not outputs or "boxes" not in outputs[0]:
        return rgb_image, "No text detected"
    boxes = outputs[0]["boxes"]

    # Crop every usable region from the clean RGB image (before any boxes are
    # drawn), skipping degenerate or unrecognised boxes.
    height, width = image_np.shape[:2]
    crops = []
    for box in boxes:
        rect = _box_to_crop_rect(box, width, height)
        if rect is not None:
            crops.append(rgb_image.crop(rect))

    texts = _recognize_crops(crops) if crops else []

    # Draw boxes on the image.
    result_pil = Image.fromarray(draw_boxes(image_np, boxes))

    # Drop empty transcriptions so the joined text has no doubled spaces.
    recognized = [text.strip() for text in texts if text.strip()]
    text_data = " ".join(recognized) if recognized else "No text recognized"
    return result_pil, f"Recognized text: {text_data}"
# Gradio UI: one PIL image goes in; the annotated image and a text box with the
# recognition result come out.
_input_component = gr.Image(type="pil")
_output_components = [gr.Image(type="pil"), gr.Text()]

interface = gr.Interface(
    fn=recognize_handwritten_text,
    inputs=_input_component,
    outputs=_output_components,
    title="Handwritten Text Detection and Recognition",
    description="Upload an image to detect and recognize handwritten text.",
)

# Start the web server when this module runs.
interface.launch()