import asyncio
import base64
import io
import logging
import os

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor

from utils import (
    check_ocr_box,
    get_caption_model_processor,
    get_som_labeled_img,
    get_yolo_model,
)

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

CAPTION_PROCESSOR_NAME = "microsoft/Florence-2-base"
CAPTION_MODEL_PATH = "weights/icon_caption_florence"

# Pick the compute device once and use it consistently for every model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the YOLO model and move it to the selected device.
yolo_model = get_yolo_model(model_path="weights/best.pt")
yolo_model = yolo_model.cuda() if device.type == "cuda" else yolo_model.cpu()


def _load_caption_model():
    """Load the Florence caption model, preferring float16 on the GPU.

    Falls back to the CPU when CUDA is unavailable or the GPU load fails. The
    CPU copy is loaded in float32: half-precision kernels are not implemented
    for some PyTorch ops on the CPU, so an fp16 CPU model can fail at inference.
    """
    if device.type == "cuda":
        try:
            return AutoModelForCausalLM.from_pretrained(
                CAPTION_MODEL_PATH,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            ).to(device)
        except Exception as e:
            logger.warning(
                f"Failed to load caption model on GPU: {e}. Falling back to CPU."
            )
    else:
        logger.info("CUDA not available; loading caption model on CPU.")
    return AutoModelForCausalLM.from_pretrained(
        CAPTION_MODEL_PATH,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )


# The processor does not depend on the device. It is loaded outside the GPU/CPU
# fallback so that a failure here surfaces directly, rather than as a NameError
# when the processor/model dict is built below.
processor = AutoProcessor.from_pretrained(
    CAPTION_PROCESSOR_NAME, trust_remote_code=True
)
model = _load_caption_model()
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")

# Initialize FastAPI app
app = FastAPI()

# Jobs waiting for the single background worker. Each item is a tuple
# (image, box_threshold, iou_threshold, future). The work has NOT started when a
# job is enqueued; the worker runs the jobs strictly one at a time.
# NOTE(review): creating the queue at import time assumes Python 3.10+, where
# asyncio.Queue binds to the running loop on first use rather than at creation.
request_queue = asyncio.Queue()

# Strong reference to the worker task. The event loop only keeps weak references
# to tasks, so an unreferenced task may be garbage-collected mid-execution.
_worker_task = None


class ProcessResponse(BaseModel):
    """Response body of ``/process_image``."""

    image: str  # Base64-encoded PNG of the labeled image
    parsed_content_list: str  # Parsed elements, one per line
    label_coordinates: str  # str() of the coordinates returned by get_som_labeled_img


async def worker():
    """Process queued jobs sequentially, forever.

    For each ``(image, box_threshold, iou_threshold, future)`` job, run
    :func:`process` and resolve ``future`` with the response or the raised
    exception. Jobs whose future is already done (the waiting request was
    cancelled) are skipped. Running one job at a time matters because
    :func:`process` writes every image to the same file path.
    """
    while True:
        image_input, box_threshold, iou_threshold, future = await request_queue.get()
        try:
            if future.done():
                # The waiting request was cancelled (e.g. client disconnected).
                continue
            try:
                result = await process(image_input, box_threshold, iou_threshold)
            except asyncio.CancelledError:
                # The worker itself is being cancelled (shutdown): release the waiter.
                future.cancel()
                raise
            except Exception as e:
                logger.error(f"Error while processing task: {e}")
                if not future.done():
                    future.set_exception(e)
            else:
                if not future.done():
                    future.set_result(result)
        finally:
            request_queue.task_done()  # Mark the job as done


@app.on_event("startup")
async def startup_event():
    """Start the background worker and keep a reference to its task."""
    global _worker_task
    logger.info("Starting background worker...")
    _worker_task = asyncio.create_task(worker())


async def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR, YOLO detection and captioning on an image and label it.

    The image is first written to ``imgs/saved_image_demo.png``, because the
    helpers from ``utils`` take a file path. That file is overwritten on every
    call, so callers must not run this concurrently; :func:`worker` serializes
    the calls. The blocking model calls run in a thread via
    ``asyncio.to_thread`` so the event loop stays responsive.

    Args:
        image_input: Image to analyse.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` with the labeled image (base64 PNG), the parsed
        content joined by newlines, and the label coordinates as a string.

    Raises:
        HTTPException: status 500 if any step fails.
    """
    try:
        # Define the save path and ensure the directory exists
        image_save_path = "imgs/saved_image_demo.png"
        os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
        image_input.save(image_save_path)
        logger.debug(f"Image saved to: {image_save_path}")

        # Scale drawing parameters with the image width (3200 px -> ratio 1.0).
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            "text_scale": 0.8 * box_overlay_ratio,
            "text_thickness": max(int(2 * box_overlay_ratio), 1),
            "text_padding": max(int(3 * box_overlay_ratio), 1),
            "thickness": max(int(3 * box_overlay_ratio), 1),
        }

        ocr_bbox_rslt, is_goal_filtered = await asyncio.to_thread(
            check_ocr_box,
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt

        labeled_img_b64, label_coordinates, parsed_content_list = await asyncio.to_thread(
            get_som_labeled_img,
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )

        # Decode the labeled image and re-encode it as a base64 PNG.
        buffered = io.BytesIO()
        with Image.open(io.BytesIO(base64.b64decode(labeled_img_b64))) as image:
            image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)

        return ProcessResponse(
            image=img_str,
            parsed_content_list=parsed_content_list_str,
            label_coordinates=str(label_coordinates),
        )
    except Exception as e:
        logger.exception(f"Error in process function: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to process the image: {e}"
        ) from e


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Decode an uploaded image, queue it for processing and wait for the result.

    Args:
        image_file: Uploaded image file; converted to RGB before processing.
        box_threshold: Forwarded to :func:`process`.
        iou_threshold: Forwarded to :func:`process`.

    Returns:
        The ``ProcessResponse`` produced by :func:`process`.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure.
    """
    try:
        contents = await image_file.read()
        try:
            with Image.open(io.BytesIO(contents)) as uploaded:
                image_input = uploaded.convert("RGB")
        except UnidentifiedImageError as e:
            logger.error(f"Unsupported image format: {e}")
            raise HTTPException(
                status_code=400, detail="Unsupported image format."
            ) from e
        except OSError as e:
            # Recognised format, but the data is truncated or corrupt.
            logger.error(f"Could not decode image: {e}")
            raise HTTPException(
                status_code=400, detail="Corrupted or truncated image."
            ) from e

        # The worker resolves this future; the work itself starts only when the
        # worker dequeues the job, so requests are processed one at a time.
        future = asyncio.get_running_loop().create_future()
        await request_queue.put((image_input, box_threshold, iou_threshold, future))
        logger.info(f"Task added to queue. Current queue size: {request_queue.qsize()}")

        return await future
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error processing image: {e}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {e}"
        ) from e