from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
import logging
from PIL import Image
import torch
import asyncio
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)
from transformers import AutoProcessor, AutoModelForCausalLM

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Load YOLO model
yolo_model = get_yolo_model(model_path="weights/best.pt")

# Handle device placement
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    yolo_model = yolo_model.cuda()
else:
    yolo_model = yolo_model.cpu()


def _load_caption_model(model_path: str = "weights/icon_caption_florence"):
    """Load the icon-caption model, preferring the GPU.

    On CUDA the model is loaded in float16 and moved to ``device``. If CUDA is
    unavailable, or the GPU load fails, the model is loaded on the CPU in
    float32, because half precision is poorly supported by CPU kernels.
    """
    if device.type == "cuda":
        try:
            return AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            ).to(device)
        except Exception as e:
            logger.warning(
                "Failed to load caption model on GPU: %s. Falling back to CPU.", e
            )
    else:
        logger.warning("CUDA is not available. Loading caption model on CPU.")
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )


# The processor does not depend on the device, so it is loaded outside the
# GPU/CPU fallback: a failure here raises immediately instead of leaving
# `processor` undefined.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
model = _load_caption_model()
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")

app = FastAPI()


class ProcessResponse(BaseModel):
    """Response body of the ``/process_image`` endpoint."""

    image: str  # Base64-encoded PNG of the annotated image
    parsed_content_list: str  # str() of each parsed item, joined by newlines
    label_coordinates: str  # str() of the coordinates from get_som_labeled_img


# Scratch file the utils helpers read the input image from (they take a path).
IMAGE_SAVE_PATH = "imgs/saved_image_demo.png"

# Serialises processing: all requests share the models and the single scratch
# file above, so two requests running at once would overwrite each other's
# input. asyncio.Lock wakes waiters in FIFO order.
processing_lock = asyncio.Lock()


def _reencode_as_png_base64(image_b64: str) -> str:
    """Decode a base64-encoded image and return it re-encoded as base64 PNG."""
    image = Image.open(io.BytesIO(base64.b64decode(image_b64)))
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


async def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR, icon detection and captioning on an image.

    The image is written to ``IMAGE_SAVE_PATH``. Then ``check_ocr_box`` and
    ``get_som_labeled_img`` run on that file in worker threads, so the event
    loop is not blocked. The write-then-infer sequence holds
    ``processing_lock``, so concurrent calls run one at a time.

    Args:
        image_input: Image to analyse.
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` containing:
        - the annotated image, re-encoded as base64 PNG;
        - ``str()`` of each parsed item, joined by newlines;
        - ``str()`` of the label coordinates.

    Raises:
        Exception: Any error from saving or inference is logged with its
            traceback and re-raised unchanged.
    """
    try:
        os.makedirs(os.path.dirname(IMAGE_SAVE_PATH), exist_ok=True)

        # Scale overlay text/line sizes with image width relative to 3200 px,
        # never letting the integer sizes drop below 1.
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            "text_scale": 0.8 * box_overlay_ratio,
            "text_thickness": max(int(2 * box_overlay_ratio), 1),
            "text_padding": max(int(3 * box_overlay_ratio), 1),
            "thickness": max(int(3 * box_overlay_ratio), 1),
        }

        async with processing_lock:
            # Persist this request's image: both helpers read IMAGE_SAVE_PATH,
            # so without this write they would process a stale or missing file.
            await asyncio.to_thread(image_input.save, IMAGE_SAVE_PATH, format="PNG")

            ocr_bbox_rslt, _goal_filtered = await asyncio.to_thread(
                check_ocr_box,
                IMAGE_SAVE_PATH,
                display_img=False,
                output_bb_format="xyxy",
                goal_filtering=None,
                easyocr_args={"paragraph": False, "text_threshold": 0.9},
                use_paddleocr=True,
            )
            text, ocr_bbox = ocr_bbox_rslt

            labeled_img_b64, label_coordinates, parsed_content_list = (
                await asyncio.to_thread(
                    get_som_labeled_img,
                    IMAGE_SAVE_PATH,
                    yolo_model,
                    BOX_TRESHOLD=box_threshold,
                    output_coord_in_ratio=True,
                    ocr_bbox=ocr_bbox,
                    draw_bbox_config=draw_bbox_config,
                    caption_model_processor=caption_model_processor,
                    ocr_text=text,
                    iou_threshold=iou_threshold,
                )
            )

        # Normalise the annotated image to base64 PNG off the event loop.
        img_str = await asyncio.to_thread(_reencode_as_png_base64, labeled_img_b64)
        parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)

        return ProcessResponse(
            image=img_str,
            parsed_content_list=parsed_content_list_str,
            label_coordinates=str(label_coordinates),
        )
    except Exception as e:
        logger.exception("Error in process function: %s", e)
        raise


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Decode an uploaded image, process it, and return the parsed result.

    Responds with 400 if the upload cannot be decoded as an image, and with
    500 (detail = error message) if processing fails. Serialisation of
    concurrent requests is handled inside ``process``.
    """
    contents = await image_file.read()
    try:
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as e:  # PIL.UnidentifiedImageError is a subclass of OSError
        logger.warning("Rejected upload that is not a valid image: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    try:
        return await process(image_input, box_threshold, iou_threshold)
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e