yuragoithf committed verified commit f518509 · 1 parent: 6b605c4

Update app.py

Files changed (1)
  1. app.py +33 -30
app.py CHANGED
@@ -1,51 +1,54 @@
-import torch
+from hezar.models import Model
+from hezar.utils import load_image, draw_boxes
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel
-from PIL import Image
-import cv2
-import numpy as np
-from craft_text_detector import Craft
 import gradio as gr
+import numpy as np
+from PIL import Image
 
-# Force CPU usage
-torch.set_default_device('cpu')
-
-# Load model and processor
+# Load models on CPU (Hugging Face Spaces default)
+craft_model = Model.load("hezarai/CRAFT", device="cpu")
 processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
-model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
-craft = Craft(output_dir=None, crop_type="box", cuda=False)
+trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
 
-def recognize_handwritten(image):
-    # Convert Gradio image to OpenCV format
+def recognize_handwritten_text(image):
+    # Convert Gradio image to format compatible with hezar
     image_np = np.array(image)
-    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+    processed_image = load_image(image_np)
+
+    # Detect text regions with CRAFT
+    outputs = craft_model.predict(processed_image)
+    if not outputs or "boxes" not in outputs[0]:
+        return Image.fromarray(processed_image), "No text detected"
 
-    # Detect text regions with Craft
-    result = craft.detect_text(image=image_cv)
-    boxes = result["boxes"]
-    pil_image = Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))
+    boxes = outputs[0]["boxes"]
+    pil_image = Image.fromarray(processed_image)
     texts = []
 
+    # Recognize text in each detected region
     for box in boxes:
-        crop = pil_image.crop([box[0][0], box[0][1], box[2][0], box[2][1]])
+        x_min, y_min, x_max, y_max = box[0][0], box[0][1], box[2][0], box[2][1]
+        crop = pil_image.crop((x_min, y_min, x_max, y_max))
         pixel_values = processor(images=crop, return_tensors="pt").pixel_values
-        generated_ids = model.generate(pixel_values)
+        generated_ids = trocr_model.generate(pixel_values)
         text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         texts.append(text)
 
-    text_data = " ".join(texts) if texts else "No text detected"
-    return f"Recognized text: {text_data}"
+    # Draw boxes on the image
+    result_image = draw_boxes(processed_image, boxes)
+    result_pil = Image.fromarray(result_image)
+
+    # Join recognized texts
+    text_data = " ".join(texts) if texts else "No text recognized"
+    return result_pil, f"Recognized text: {text_data}"
 
 # Create Gradio interface
 interface = gr.Interface(
-    fn=recognize_handwritten,
+    fn=recognize_handwritten_text,
     inputs=gr.Image(type="pil"),
-    outputs="text",
-    title="Handwritten Text Recognition",
-    description="Upload an image containing handwritten text to recognize it."
+    outputs=[gr.Image(type="pil"), gr.Text()],
+    title="Handwritten Text Detection and Recognition",
+    description="Upload an image to detect and recognize handwritten text."
 )
 
 # Launch the app
-interface.launch()
-
-# Cleanup
-craft.unload_craftnet_model()
+interface.launch()
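
For a quick local check that the new hezar-based detector returns output in the shape the updated handler indexes (outputs[0]["boxes"] with four-point boxes), something along these lines could be run before deploying. This is only a sketch that mirrors the committed code's assumptions about the hezar API; "sample.png" is a hypothetical local test image.

# Sanity-check sketch (not part of the commit): verify that hezarai/CRAFT's
# predict() output matches what the updated app.py expects.
# "sample.png" is a hypothetical local test image.
import numpy as np
from PIL import Image
from hezar.models import Model
from hezar.utils import load_image

craft_model = Model.load("hezarai/CRAFT", device="cpu")
image = load_image(np.array(Image.open("sample.png").convert("RGB")))
outputs = craft_model.predict(image)

print(type(outputs), len(outputs))   # expected: a list with one entry per input image
print(list(outputs[0].keys()))       # expected to include "boxes"
print(outputs[0]["boxes"][:3])       # first few detected box polygons

If this app runs as a Space, its requirements.txt would presumably also need to swap craft_text_detector for hezar, alongside transformers, torch, gradio, pillow, and numpy.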