banao-tech committed on
Commit
d9307fe
·
verified ·
1 Parent(s): d0b9c8a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +23 -33
main.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from pydantic import BaseModel#
3
  import base64
4
  import io
5
  import os
@@ -24,10 +24,7 @@ yolo_model = get_yolo_model(model_path="weights/best.pt")
24
 
25
  # Handle device placement
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- if str(device) == "cuda":
28
- yolo_model = yolo_model.cuda()
29
- else:
30
- yolo_model = yolo_model.cpu()
31
 
32
  # Load caption model and processor
33
  try:
@@ -38,7 +35,7 @@ try:
38
  "weights/icon_caption_florence",
39
  torch_dtype=torch.float16,
40
  trust_remote_code=True,
41
- ).to("cuda")
42
  except Exception as e:
43
  logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
44
  model = AutoModelForCausalLM.from_pretrained(
@@ -48,7 +45,7 @@ except Exception as e:
48
  )
49
 
50
  caption_model_processor = {"processor": processor, "model": model}
51
- logger.info("Finished loading models!!!")
52
 
53
  # Initialize FastAPI app
54
  app = FastAPI()
@@ -56,51 +53,44 @@ app = FastAPI()
56
  MAX_QUEUE_SIZE = 10 # Set a reasonable limit based on your system capacity
57
  request_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
58
 
59
- # Define a response model for the processed image
60
  class ProcessResponse(BaseModel):
61
  image: str # Base64 encoded image
62
  parsed_content_list: str
63
  label_coordinates: str
64
 
65
 
66
- # Define the async worker function
67
  async def worker():
68
- """
69
- Background worker to process tasks from the request queue sequentially.
70
- """
71
  while True:
72
- task = await request_queue.get() # Get the next task from the queue
73
  try:
74
- await task # Process the task
75
  except Exception as e:
76
  logger.error(f"Error while processing task: {e}")
77
  finally:
78
- request_queue.task_done() # Mark the task as done
79
 
80
 
81
- # Start the worker when the application starts
82
  @app.on_event("startup")
83
  async def startup_event():
84
  logger.info("Starting background worker...")
85
-
86
- asyncio.create_task(worker()) # Start the worker in the background
87
 
88
 
89
- # Define the process function
90
  async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
91
- """
92
- Asynchronously processes an image using YOLO and caption models.
93
- """
94
  try:
95
- # Define the save path and ensure the directory exists
96
  image_save_path = "imgs/saved_image_demo.png"
97
  os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
98
 
99
- # Save the image
100
  image_input.save(image_save_path)
101
  logger.debug(f"Image saved to: {image_save_path}")
102
 
103
- # Perform YOLO and caption model inference
104
  box_overlay_ratio = image_input.size[0] / 3200
105
  draw_bbox_config = {
106
  "text_scale": 0.8 * box_overlay_ratio,
@@ -152,7 +142,7 @@ async def process(image_input: Image.Image, box_threshold: float, iou_threshold:
152
  raise HTTPException(status_code=500, detail=f"Failed to process the image: {e}")
153
 
154
 
155
- # Define the process_image endpoint
156
  @app.post("/process_image", response_model=ProcessResponse)
157
  async def process_image(
158
  image_file: UploadFile = File(...),
@@ -160,22 +150,22 @@ async def process_image(
160
  iou_threshold: float = 0.1,
161
  ):
162
  try:
163
- # Read the image file
164
  contents = await image_file.read()
165
  try:
166
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
167
- except UnidentifiedImageError as e:
168
- logger.error(f"Unsupported image format: {e}")
169
  raise HTTPException(status_code=400, detail="Unsupported image format.")
170
 
171
- # Create a task for processing
172
  task = asyncio.create_task(process(image_input, box_threshold, iou_threshold))
173
 
174
- # Add the task to the queue
175
  await request_queue.put(task)
176
  logger.info(f"Task added to queue. Current queue size: {request_queue.qsize()}")
177
 
178
- # Wait for the task to complete
179
  response = await task
180
 
181
  return response
@@ -183,4 +173,4 @@ async def process_image(
183
  raise he
184
  except Exception as e:
185
  logger.error(f"Error processing image: {e}")
186
- raise HTTPException(status_code=500, detail=f"Internal server error: {e}")#
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from pydantic import BaseModel
3
  import base64
4
  import io
5
  import os
 
24
 
25
  # Handle device placement
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ yolo_model = yolo_model.to(device)
 
 
 
28
 
29
  # Load caption model and processor
30
  try:
 
35
  "weights/icon_caption_florence",
36
  torch_dtype=torch.float16,
37
  trust_remote_code=True,
38
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
39
  except Exception as e:
40
  logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
41
  model = AutoModelForCausalLM.from_pretrained(
 
45
  )
46
 
47
  caption_model_processor = {"processor": processor, "model": model}
48
+ logger.info("Finished loading models!")
49
 
50
  # Initialize FastAPI app
51
  app = FastAPI()
 
53
  MAX_QUEUE_SIZE = 10 # Set a reasonable limit based on your system capacity
54
  request_queue = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
55
 
56
+ # Define response model
57
  class ProcessResponse(BaseModel):
58
  image: str # Base64 encoded image
59
  parsed_content_list: str
60
  label_coordinates: str
61
 
62
 
63
+ # Background worker to process queue tasks
64
  async def worker():
 
 
 
65
  while True:
66
+ task = await request_queue.get()
67
  try:
68
+ await task
69
  except Exception as e:
70
  logger.error(f"Error while processing task: {e}")
71
  finally:
72
+ request_queue.task_done()
73
 
74
 
75
+ # Start worker on startup
76
  @app.on_event("startup")
77
  async def startup_event():
78
  logger.info("Starting background worker...")
79
+ asyncio.create_task(worker())
 
80
 
81
 
82
+ # Image processing function
83
  async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
 
 
 
84
  try:
85
+ # Define save path
86
  image_save_path = "imgs/saved_image_demo.png"
87
  os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
88
 
89
+ # Save image
90
  image_input.save(image_save_path)
91
  logger.debug(f"Image saved to: {image_save_path}")
92
 
93
+ # YOLO and caption model inference
94
  box_overlay_ratio = image_input.size[0] / 3200
95
  draw_bbox_config = {
96
  "text_scale": 0.8 * box_overlay_ratio,
 
142
  raise HTTPException(status_code=500, detail=f"Failed to process the image: {e}")
143
 
144
 
145
+ # API endpoint for processing images
146
  @app.post("/process_image", response_model=ProcessResponse)
147
  async def process_image(
148
  image_file: UploadFile = File(...),
 
150
  iou_threshold: float = 0.1,
151
  ):
152
  try:
153
+ # Read image file
154
  contents = await image_file.read()
155
  try:
156
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
157
+ except UnidentifiedImageError:
158
+ logger.error("Unsupported image format.")
159
  raise HTTPException(status_code=400, detail="Unsupported image format.")
160
 
161
+ # Create processing task
162
  task = asyncio.create_task(process(image_input, box_threshold, iou_threshold))
163
 
164
+ # Add task to queue
165
  await request_queue.put(task)
166
  logger.info(f"Task added to queue. Current queue size: {request_queue.qsize()}")
167
 
168
+ # Wait for processing to complete
169
  response = await task
170
 
171
  return response
 
173
  raise he
174
  except Exception as e:
175
  logger.error(f"Error processing image: {e}")
176
+ raise HTTPException(status_code=500, detail=f"Internal server error: {e}")