import os
import tempfile
from hezar.models import Model
from hezar.utils import load_image, draw_boxes
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import gradio as gr
import numpy as np
from PIL import Image
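
# Expected runtime dependencies (list inferred from the imports above --
# an assumption; pin exact versions in the Space's requirements.txt):
#   hezar, transformers, torch, gradio, pillow, numpy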

# Load models on CPU (Hugging Face Spaces default)
craft_model = Model.load("hezarai/CRAFT", device="cpu")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
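
# A minimal sketch for opting into GPU when available (assumes torch is
# installed as a transformers dependency; untested on Spaces GPU hardware):
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# craft_model = Model.load("hezarai/CRAFT", device=device)
# trocr_model = trocr_model.to(device)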

def recognize_handwritten_text(image):
    if image is None:
        return None, "No image provided"
    try:
        # Ensure the input is a PIL image, and force RGB so that RGBA or
        # palette uploads (e.g. PNGs) can be saved as JPEG below
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))
        image = image.convert("RGB")
        
        # Save the uploaded image to a temporary file in JPEG format
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            image.save(tmp_file.name, format="JPEG")
            tmp_path = tmp_file.name
        
        # Load image with hezar utils using file path
        processed_image = load_image(tmp_path)
        
        # Ensure processed_image is in a compatible format (convert to NumPy if needed)
        if not isinstance(processed_image, np.ndarray):
            processed_image = np.array(Image.open(tmp_path))
        
        # Detect text regions with CRAFT
        outputs = craft_model.predict(processed_image)
        if not outputs or "boxes" not in outputs[0]:
            return Image.fromarray(processed_image), "No text detected"
        
        boxes = outputs[0]["boxes"]
        print(f"Debug: Boxes structure = {boxes}")  # Log the exact structure
        pil_image = Image.fromarray(processed_image)
        texts = []
        
        # Handle the box formats CRAFT may emit: a flat [x, y, width, height]
        # or a list of [x, y] corner points (a corner pair or a 4-point
        # polygon); len(box) alone cannot tell a flat box from a polygon apart
        for box in boxes:
            box_arr = np.asarray(box)
            if box_arr.ndim == 1 and box_arr.size == 4:  # [x, y, width, height]
                x, y, width, height = box_arr
                x_min, y_min = x, y
                x_max, y_max = x + width, y + height
            elif box_arr.ndim == 2 and box_arr.shape[1] == 2:  # corner points
                x_min, y_min = box_arr.min(axis=0)
                x_max, y_max = box_arr.max(axis=0)
            else:
                print(f"Debug: Skipping invalid box {box}")  # Log unrecognized boxes
                continue
            
            # crop() expects integer pixel coordinates
            crop = pil_image.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
            pixel_values = processor(images=crop, return_tensors="pt").pixel_values
            generated_ids = trocr_model.generate(pixel_values)
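            # generate() defaults to greedy decoding; beam search, e.g.
            # trocr_model.generate(pixel_values, num_beams=4), may trade
            # speed for accuracy -- an untested tuning suggestion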
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            texts.append(text)
        
        # Draw boxes on the image
        result_image = draw_boxes(processed_image, boxes)
        result_pil = Image.fromarray(result_image)
        
        # Join recognized texts
        text_data = " ".join(texts) if texts else "No text recognized"
        return result_pil, f"Recognized text: {text_data}"
    
    except Exception as e:
        return image if isinstance(image, Image.Image) else None, f"Error: {str(e)}"
    finally:
        # Clean up temporary file
        if 'tmp_path' in locals():
            os.unlink(tmp_path)
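
# Quick sanity check without the UI (assumes a local file named
# "sample.jpg"; the filename is purely illustrative):
# img, text = recognize_handwritten_text(Image.open("sample.jpg"))
# print(text)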

# Create Gradio interface
interface = gr.Interface(
    fn=recognize_handwritten_text,
    inputs=gr.Image(type="pil", label="Upload any image format"),
    outputs=[gr.Image(type="pil", label="Detected Text Image"), gr.Text(label="Recognized Text")],
    title="Handwritten Text Detection and Recognition",
    description="Upload an image in any format (JPEG, PNG, BMP, etc.) to detect and recognize handwritten text."
)
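
# On the free CPU tier, long CRAFT + TrOCR runs can exceed request timeouts;
# enabling Gradio's request queue is one optional mitigation:
# interface.queue()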

# Launch the app
interface.launch()