yuragoithf commited on
Commit
c547456
·
verified ·
1 Parent(s): e74b235

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from hezar.models import Model
2
  from hezar.utils import load_image, draw_boxes
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
@@ -12,13 +14,13 @@ trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-ha
12
 
13
  def recognize_handwritten_text(image):
14
  try:
15
- # Ensure image is a PIL image and convert to NumPy array
16
- if not isinstance(image, Image.Image):
17
- image = Image.fromarray(np.array(image)).convert("RGB")
18
- image_np = np.array(image)
19
 
20
- # Load image with hezar utils
21
- processed_image = load_image(image_np)
22
 
23
  # Detect text regions with CRAFT
24
  outputs = craft_model.predict(processed_image)
@@ -47,7 +49,11 @@ def recognize_handwritten_text(image):
47
  return result_pil, f"Recognized text: {text_data}"
48
 
49
  except Exception as e:
50
- return Image.fromarray(image_np), f"Error: {str(e)}"
 
 
 
 
51
 
52
  # Create Gradio interface
53
  interface = gr.Interface(
 
1
+ import os
2
+ import tempfile
3
  from hezar.models import Model
4
  from hezar.utils import load_image, draw_boxes
5
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
14
 
15
  def recognize_handwritten_text(image):
16
  try:
17
+ # Save the uploaded image to a temporary file
18
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
19
+ image.save(tmp_file.name, format="JPEG")
20
+ tmp_path = tmp_file.name
21
 
22
+ # Load image with hezar utils using file path
23
+ processed_image = load_image(tmp_path)
24
 
25
  # Detect text regions with CRAFT
26
  outputs = craft_model.predict(processed_image)
 
49
  return result_pil, f"Recognized text: {text_data}"
50
 
51
  except Exception as e:
52
+ return Image.fromarray(np.array(image)), f"Error: {str(e)}"
53
+ finally:
54
+ # Clean up temporary file
55
+ if 'tmp_path' in locals():
56
+ os.unlink(tmp_path)
57
 
58
  # Create Gradio interface
59
  interface = gr.Interface(