yuragoithf committed on
Commit
7639be4
·
verified ·
1 Parent(s): 1e268f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -8
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import torch
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
 
 
 
4
  import gradio as gr
5
 
6
  # Force CPU usage
@@ -9,17 +12,28 @@ torch.set_default_device('cpu')
9
  # Load model and processor
10
  processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
11
  model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
 
12
 
def recognize_handwritten(image):
    """Transcribe the handwriting in ``image`` with the TrOCR model.

    The uploaded image is converted to RGB, turned into pixel values by the
    module-level ``processor``, decoded with ``model.generate`` and returned
    as a display string prefixed with "Recognized text: ".
    """
    rgb_image = image.convert("RGB")
    encoded = processor(images=rgb_image, return_tensors="pt")

    # Generate token ids and decode the first (only) sequence
    token_ids = model.generate(encoded.pixel_values)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return f"Recognized text: {decoded[0]}"
 
 
 
 
 
 
 
 
23
 
24
  # Create Gradio interface
25
  interface = gr.Interface(
@@ -31,4 +45,7 @@ interface = gr.Interface(
31
  )
32
 
33
  # Launch the app
34
- interface.launch()
 
 
 
 
1
  import torch
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
+ import cv2
5
+ import numpy as np
6
+ from craft_text_detector import Craft
7
  import gradio as gr
8
 
9
  # Force CPU usage
 
# Load the TrOCR processor/model pair once at import time; both come from
# the same pretrained checkpoint.
_TROCR_CHECKPOINT = 'microsoft/trocr-base-handwritten'
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)

# Text-region detector, created with CUDA disabled
craft = Craft(output_dir=None, crop_type="box", cuda=False)
16
 
def recognize_handwritten(image):
    """Detect text regions in ``image`` and transcribe each one with TrOCR.

    Args:
        image: The image handed over by the Gradio input (a PIL image, or an
            array convertible with ``Image.fromarray``), or ``None`` when
            nothing was uploaded.

    Returns:
        A string starting with "Recognized text: " followed by the
        space-joined transcriptions of every detected region, or by
        "No text detected" when there is nothing to transcribe.
    """
    if image is None:
        return "Recognized text: No text detected"

    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))

    # Normalise to 3-channel RGB first: cv2.COLOR_RGB2BGR rejects grayscale
    # and RGBA arrays, and TrOCR expects RGB crops.
    pil_image = image.convert("RGB")
    width, height = pil_image.size

    # Convert Gradio image to OpenCV format
    image_cv = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    # Detect text regions with Craft
    result = craft.detect_text(image=image_cv)

    texts = []
    for box in result["boxes"]:
        # Crop the bounding rectangle of all corner points, clamped to the
        # image: relying on corners 0 and 2 alone can give inverted or
        # out-of-bounds coordinates for rotated or edge-touching boxes.
        points = np.asarray(box, dtype=float).reshape(-1, 2)
        left = max(0, int(np.floor(points[:, 0].min())))
        top = max(0, int(np.floor(points[:, 1].min())))
        right = min(width, int(np.ceil(points[:, 0].max())))
        bottom = min(height, int(np.ceil(points[:, 1].max())))
        if right <= left or bottom <= top:
            continue  # degenerate region, nothing to recognise

        crop = pil_image.crop((left, top, right, bottom))
        pixel_values = processor(images=crop, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        texts.append(text)

    text_data = " ".join(texts) if texts else "No text detected"
    return f"Recognized text: {text_data}"
37
 
38
  # Create Gradio interface
39
  interface = gr.Interface(
 
45
  )
46
 
# Launch the app. The cleanup lives in a finally clause so the detector
# weights are released even if launch() raises or is interrupted.
try:
    interface.launch()
finally:
    # Cleanup: drop both networks held by the Craft instance
    craft.unload_craftnet_model()
    craft.unload_refinenet_model()