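"""Gradio demo: handwritten text detection and recognition.

Pipeline: detect text regions with hezar's CRAFT model, crop each detected
box, and decode it with Microsoft's TrOCR (trocr-base-handwritten).
"""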
import os
import tempfile

import gradio as gr
import numpy as np
from PIL import Image

from hezar.models import Model
from hezar.utils import load_image, draw_boxes
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# Load models on CPU (Hugging Face Spaces default)
craft_model = Model.load("hezarai/CRAFT", device="cpu")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
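# Note: both checkpoints are fetched from the Hugging Face Hub on first run,
# so the initial startup of a fresh Space can take a while.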
def recognize_handwritten_text(image):
    tmp_path = None
    try:
        # Ensure the input is a PIL image in RGB mode (JPEG cannot store alpha)
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))
        image = image.convert("RGB")

        # Save the uploaded image to a temporary file in JPEG format
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            image.save(tmp_file.name, format="JPEG")
            tmp_path = tmp_file.name
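        # The JPEG round-trip is presumably here to normalize arbitrary upload
        # formats (PNG, BMP, RGBA, ...) into one that load_image handles.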
        # Load the image with hezar utils using the file path
        processed_image = load_image(tmp_path)

        # Ensure processed_image is a NumPy array (convert if needed)
        if not isinstance(processed_image, np.ndarray):
            processed_image = np.array(Image.open(tmp_path))
        # Detect text regions with CRAFT
        outputs = craft_model.predict(processed_image)
        if not outputs or "boxes" not in outputs[0]:
            return Image.fromarray(processed_image), "No text detected"
        boxes = outputs[0]["boxes"]
        print(f"Debug: Boxes structure = {boxes}")  # Log the exact structure

        pil_image = Image.fromarray(processed_image)
        texts = []
        # Handle both box formats ([x, y, width, height] or [[x1, y1], [x2, y2]])
        for box in boxes:
            if len(box) == 4:  # [x, y, width, height]
                x, y, width, height = box
                x_min, y_min = x, y
                x_max, y_max = x + width, y + height
            elif len(box) == 2 and all(len(p) == 2 for p in box):  # [[x1, y1], [x2, y2]]
                (x1, y1), (x2, y2) = box
                x_min, y_min = min(x1, x2), min(y1, y2)
                x_max, y_max = max(x1, x2), max(y1, y2)
            else:
                print(f"Debug: Skipping invalid box {box}")  # Log invalid boxes
                continue
            # Crop the detected region (cast to int in case CRAFT returns floats)
            # and run TrOCR recognition on it
            crop = pil_image.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
            pixel_values = processor(images=crop, return_tensors="pt").pixel_values
            generated_ids = trocr_model.generate(pixel_values)
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            texts.append(text)
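        # Possible optimization (an assumption, not part of the original flow):
        # pass all crops to processor() as a list and decode them with a single
        # trocr_model.generate() call to reduce per-box overhead.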
        # Draw the detected boxes on the image
        result_image = draw_boxes(processed_image, boxes)
        result_pil = Image.fromarray(result_image)

        # Join the recognized texts
        text_data = " ".join(texts) if texts else "No text recognized"
        return result_pil, f"Recognized text: {text_data}"
    except Exception as e:
        # Return the original upload (if any) along with the error message
        fallback = image if isinstance(image, Image.Image) else None
        return fallback, f"Error: {str(e)}"
    finally:
        # Clean up the temporary file
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Create the Gradio interface
interface = gr.Interface(
    fn=recognize_handwritten_text,
    inputs=gr.Image(type="pil", label="Upload any image format"),
    outputs=[
        gr.Image(type="pil", label="Detected Text Image"),
        gr.Text(label="Recognized Text"),
    ],
    title="Handwritten Text Detection and Recognition",
    description="Upload an image in any format (JPEG, PNG, BMP, etc.) to detect and recognize handwritten text.",
)

# Launch the app
interface.launch()
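# When testing locally (outside Spaces), interface.launch(share=True) can
# expose a temporary public URL.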