banao-tech commited on
Commit
7ecde71
·
verified ·
1 Parent(s): 70f32bc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +78 -75
main.py CHANGED
@@ -6,8 +6,7 @@ import os
6
  import logging
7
  from PIL import Image
8
  import torch
9
-
10
- # Existing imports
11
  from utils import (
12
  check_ocr_box,
13
  get_yolo_model,
@@ -17,7 +16,7 @@ from utils import (
17
  from transformers import AutoProcessor, AutoModelForCausalLM
18
 
19
  # Configure logging
20
- logging.basicConfig(level=logging.DEBUG) # Changed to DEBUG for more verbosity
21
  logger = logging.getLogger(__name__)
22
 
23
  # Load YOLO model
@@ -58,62 +57,72 @@ class ProcessResponse(BaseModel):
58
  parsed_content_list: str
59
  label_coordinates: str
60
 
61
- def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
62
- image_save_path = "imgs/saved_image_demo.png"
63
- os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
64
- image_input.save(image_save_path)
65
-
66
- image = Image.open(image_save_path)
67
- box_overlay_ratio = image.size[0] / 3200
68
- draw_bbox_config = {
69
- "text_scale": 0.8 * box_overlay_ratio,
70
- "text_thickness": max(int(2 * box_overlay_ratio), 1),
71
- "text_padding": max(int(3 * box_overlay_ratio), 1),
72
- "thickness": max(int(3 * box_overlay_ratio), 1),
73
- }
74
-
75
- ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
76
- image_save_path,
77
- display_img=False,
78
- output_bb_format="xyxy",
79
- goal_filtering=None,
80
- easyocr_args={"paragraph": False, "text_threshold": 0.9},
81
- use_paddleocr=True,
82
- )
83
- text, ocr_bbox = ocr_bbox_rslt
84
-
85
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
86
- image_save_path,
87
- yolo_model,
88
- BOX_TRESHOLD=box_threshold,
89
- output_coord_in_ratio=True,
90
- ocr_bbox=ocr_bbox,
91
- draw_bbox_config=draw_bbox_config,
92
- caption_model_processor=caption_model_processor,
93
- ocr_text=text,
94
- iou_threshold=iou_threshold,
95
- )
96
-
97
- # Log parsed_content_list to inspect its structure before joining
98
- logger.info(f"Parsed content list before join: {parsed_content_list}")
99
-
100
- # Ensure parsed_content_list is a list of strings, not dictionaries
101
- parsed_content_list_str = "\n".join([str(item) for item in parsed_content_list])
102
-
103
- image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
104
- print("Finish processing")
105
-
106
- # Convert the image to base64
107
- buffered = io.BytesIO()
108
- image.save(buffered, format="PNG")
109
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
110
-
111
- return ProcessResponse(
112
- image=img_str,
113
- parsed_content_list=parsed_content_list_str,
114
- label_coordinates=str(label_coordinates),
115
- )
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  @app.post("/process_image", response_model=ProcessResponse)
119
  async def process_image(
@@ -122,28 +131,22 @@ async def process_image(
122
  iou_threshold: float = 0.1,
123
  ):
124
  try:
 
125
  contents = await image_file.read()
126
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
127
 
128
- logger.info(f"Processing image: {image_file.filename}")
129
- logger.info(f"Image size: {image_input.size}")
130
-
131
- # Debugging the input image
132
- if not image_input:
133
- raise ValueError("Image input is empty or invalid.")
134
-
135
- response = process(image_input, box_threshold, iou_threshold)
136
 
137
- # Ensure the response contains an image
138
- if not response.image:
139
- raise ValueError("Empty image in response")
140
-
141
- logger.info("Processing complete, returning response.")
142
  return response
143
-
144
  except Exception as e:
145
  logger.error(f"Error processing image: {e}")
146
- import traceback
147
- traceback.print_exc()
148
  raise HTTPException(status_code=500, detail=str(e))
149
-
 
6
  import logging
7
  from PIL import Image
8
  import torch
9
+ import asyncio
 
10
  from utils import (
11
  check_ocr_box,
12
  get_yolo_model,
 
16
  from transformers import AutoProcessor, AutoModelForCausalLM
17
 
18
  # Configure logging
19
+ logging.basicConfig(level=logging.DEBUG)
20
  logger = logging.getLogger(__name__)
21
 
22
  # Load YOLO model
 
57
  parsed_content_list: str
58
  label_coordinates: str
59
 
60
+ # Create a queue for sequential processing
61
+ request_queue = asyncio.Queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
64
+ """
65
+ Asynchronously processes an image using YOLO and caption models.
66
+ """
67
+ try:
68
+ image_save_path = "imgs/saved_image_demo.png"
69
+ os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
70
+
71
+ # Save the image asynchronously
72
+ buffer = io.BytesIO()
73
+ image_input.save(buffer, format="PNG")
74
+ buffer.seek(0)
75
+
76
+ # Perform YOLO and caption model inference
77
+ box_overlay_ratio = image_input.size[0] / 3200
78
+ draw_bbox_config = {
79
+ "text_scale": 0.8 * box_overlay_ratio,
80
+ "text_thickness": max(int(2 * box_overlay_ratio), 1),
81
+ "text_padding": max(int(3 * box_overlay_ratio), 1),
82
+ "thickness": max(int(3 * box_overlay_ratio), 1),
83
+ }
84
+
85
+ ocr_bbox_rslt, is_goal_filtered = await asyncio.to_thread(
86
+ check_ocr_box,
87
+ image_save_path,
88
+ display_img=False,
89
+ output_bb_format="xyxy",
90
+ goal_filtering=None,
91
+ easyocr_args={"paragraph": False, "text_threshold": 0.9},
92
+ use_paddleocr=True,
93
+ )
94
+ text, ocr_bbox = ocr_bbox_rslt
95
+
96
+ dino_labled_img, label_coordinates, parsed_content_list = await asyncio.to_thread(
97
+ get_som_labeled_img,
98
+ image_save_path,
99
+ yolo_model,
100
+ BOX_TRESHOLD=box_threshold,
101
+ output_coord_in_ratio=True,
102
+ ocr_bbox=ocr_bbox,
103
+ draw_bbox_config=draw_bbox_config,
104
+ caption_model_processor=caption_model_processor,
105
+ ocr_text=text,
106
+ iou_threshold=iou_threshold,
107
+ )
108
+
109
+ # Convert image to base64
110
+ image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
111
+ buffered = io.BytesIO()
112
+ image.save(buffered, format="PNG")
113
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
114
+
115
+ # Join parsed content list
116
+ parsed_content_list_str = "\n".join([str(item) for item in parsed_content_list])
117
+
118
+ return ProcessResponse(
119
+ image=img_str,
120
+ parsed_content_list=parsed_content_list_str,
121
+ label_coordinates=str(label_coordinates),
122
+ )
123
+ except Exception as e:
124
+ logger.error(f"Error in process function: {e}")
125
+ raise
126
 
127
  @app.post("/process_image", response_model=ProcessResponse)
128
  async def process_image(
 
131
  iou_threshold: float = 0.1,
132
  ):
133
  try:
134
+ # Read the image file
135
  contents = await image_file.read()
136
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
137
 
138
+ # Add the task to the queue
139
+ task = asyncio.create_task(
140
+ process(image_input, box_threshold, iou_threshold)
141
+ )
142
+ await request_queue.put(task)
 
 
 
143
 
144
+ # Process the next task in the queue
145
+ task = await request_queue.get()
146
+ response = await task
147
+ request_queue.task_done()
148
+
149
  return response
 
150
  except Exception as e:
151
  logger.error(f"Error processing image: {e}")
 
 
152
  raise HTTPException(status_code=500, detail=str(e))