yuragoithf commited on
Commit
e44b901
·
verified ·
1 Parent(s): 170cbbb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -15,7 +15,11 @@ trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-ha
15
 
16
  def recognize_handwritten_text(image):
17
  try:
18
- # Save the uploaded image to a temporary file
 
 
 
 
19
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
20
  image.save(tmp_file.name, format="JPEG")
21
  tmp_path = tmp_file.name
@@ -36,9 +40,16 @@ def recognize_handwritten_text(image):
36
  pil_image = Image.fromarray(processed_image)
37
  texts = []
38
 
39
- # Recognize text in each detected region
40
  for box in boxes:
41
- x_min, y_min, x_max, y_max = box[0][0], box[0][1], box[2][0], box[2][1]
 
 
 
 
 
 
 
42
  crop = pil_image.crop((x_min, y_min, x_max, y_max))
43
  pixel_values = processor(images=crop, return_tensors="pt").pixel_values
44
  generated_ids = trocr_model.generate(pixel_values)
@@ -63,10 +74,10 @@ def recognize_handwritten_text(image):
63
  # Create Gradio interface
64
  interface = gr.Interface(
65
  fn=recognize_handwritten_text,
66
- inputs=gr.Image(type="pil"),
67
- outputs=[gr.Image(type="pil"), gr.Text()],
68
  title="Handwritten Text Detection and Recognition",
69
- description="Upload an image to detect and recognize handwritten text."
70
  )
71
 
72
  # Launch the app
 
15
 
16
  def recognize_handwritten_text(image):
17
  try:
18
+ # Ensure image is a PIL image and convert to a compatible format
19
+ if not isinstance(image, Image.Image):
20
+ image = Image.fromarray(np.array(image)).convert("RGB")
21
+
22
+ # Save the uploaded image to a temporary file in JPEG format
23
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
24
  image.save(tmp_file.name, format="JPEG")
25
  tmp_path = tmp_file.name
 
40
  pil_image = Image.fromarray(processed_image)
41
  texts = []
42
 
43
+ # Adjust box unpacking based on actual structure
44
  for box in boxes:
45
+ if len(box) >= 4: # Check if box has at least 4 coordinates
46
+ x_min, y_min, x_max, y_max = box[0][0], box[0][1], box[2][0], box[2][1]
47
+ elif len(box) == 2: # Handle case with only 2 points (e.g., center and size)
48
+ x_min, y_min = box[0][0] - box[1][0] / 2, box[0][1] - box[1][1] / 2
49
+ x_max, y_max = box[0][0] + box[1][0] / 2, box[0][1] + box[1][1] / 2
50
+ else:
51
+ continue # Skip invalid boxes
52
+
53
  crop = pil_image.crop((x_min, y_min, x_max, y_max))
54
  pixel_values = processor(images=crop, return_tensors="pt").pixel_values
55
  generated_ids = trocr_model.generate(pixel_values)
 
74
  # Create Gradio interface
75
  interface = gr.Interface(
76
  fn=recognize_handwritten_text,
77
+ inputs=gr.Image(type="pil", label="Upload any image format"),
78
+ outputs=[gr.Image(type="pil", label="Detected Text Image"), gr.Text(label="Recognized Text")],
79
  title="Handwritten Text Detection and Recognition",
80
+ description="Upload an image in any format (JPEG, PNG, BMP, etc.) to detect and recognize handwritten text."
81
  )
82
 
83
  # Launch the app