"""FastAPI service that labels UI screenshots.

At import time the module loads a YOLO detection model and a Florence-2 based
caption model. It exposes one endpoint, ``POST /process_image``, which runs
OCR, detection and captioning on an uploaded image and returns the annotated
image (base64-encoded PNG) together with the parsed content and the label
coordinates.
"""

import asyncio  # Used to off-load blocking work to the default thread pool
import base64
import io
import logging
import os
import tempfile

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForCausalLM

# Existing imports
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)

# Configure logging
logging.basicConfig(level=logging.DEBUG)  # DEBUG for more verbosity
logger = logging.getLogger(__name__)

# Directory that receives the per-request scratch copy of the uploaded image.
IMAGE_SAVE_DIR = "imgs"
CAPTION_MODEL_PATH = "weights/icon_caption_florence"

# Load YOLO model
yolo_model = get_yolo_model(model_path="weights/best.pt")

# Handle device placement
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    yolo_model = yolo_model.cuda()
else:
    yolo_model = yolo_model.cpu()

# Load the processor outside the GPU try/except: a processor failure is not a
# GPU problem, and loading it here guarantees `processor` is always bound
# before it is used below.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)

# Load caption model, preferring the GPU and falling back to the CPU.
try:
    model = AutoModelForCausalLM.from_pretrained(
        CAPTION_MODEL_PATH,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to("cuda")
except Exception as e:
    logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
    # Half precision is poorly supported by CPU kernels, so the CPU fallback
    # uses full precision instead of float16.
    model = AutoModelForCausalLM.from_pretrained(
        CAPTION_MODEL_PATH,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")

app = FastAPI()


class ProcessResponse(BaseModel):
    """Response payload of the ``/process_image`` endpoint."""

    image: str  # Base64 encoded PNG of the annotated image
    parsed_content_list: str  # Parsed content entries joined with newlines
    label_coordinates: str  # str() of the label coordinates structure


def _reencode_as_png_base64(encoded_image: str) -> str:
    """Decode a base64 image and return it re-encoded as a base64 PNG string."""
    with Image.open(io.BytesIO(base64.b64decode(encoded_image))) as labeled_image:
        buffered = io.BytesIO()
        labeled_image.save(buffered, "PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


async def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR, detection and captioning on an image.

    The image is written to a unique temporary PNG inside ``IMAGE_SAVE_DIR``
    because the helper functions are called with a file path. A unique file
    per call keeps concurrent requests from overwriting each other's input;
    the file is removed once processing finishes.

    Args:
        image_input: Image to analyse.
        box_threshold: Box threshold forwarded to ``get_som_labeled_img``.
        iou_threshold: IoU threshold forwarded to ``get_som_labeled_img``.

    Returns:
        A ``ProcessResponse`` holding the annotated image as base64 PNG, the
        newline-joined parsed content and the stringified label coordinates.

    Raises:
        Exception: Any error raised while saving, running OCR, labeling or
            captioning is propagated to the caller.
    """
    loop = asyncio.get_running_loop()

    os.makedirs(IMAGE_SAVE_DIR, exist_ok=True)
    fd, image_save_path = tempfile.mkstemp(
        prefix="saved_image_", suffix=".png", dir=IMAGE_SAVE_DIR
    )
    os.close(fd)  # Only the path is needed; PIL reopens the file itself.

    try:
        # Save the image in the thread pool so the event loop is not blocked.
        await loop.run_in_executor(None, image_input.save, image_save_path)
        logger.info(f"Saved image for processing: {image_save_path}")

        # Scale the overlay drawing parameters relative to a 3200 px wide image.
        # The saved PNG has the same dimensions as `image_input`, so the width
        # is read directly instead of reopening the file.
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            "text_scale": 0.8 * box_overlay_ratio,
            "text_thickness": max(int(2 * box_overlay_ratio), 1),
            "text_padding": max(int(3 * box_overlay_ratio), 1),
            "thickness": max(int(3 * box_overlay_ratio), 1),
        }

        # OCR box processing (run in a thread pool to avoid blocking the event loop)
        ocr_bbox_rslt, is_goal_filtered = await loop.run_in_executor(
            None,
            check_ocr_box,
            image_save_path,
            False,  # display_img
            "xyxy",  # output_bb_format
            None,  # goal_filtering
            {"paragraph": False, "text_threshold": 0.9},  # easyocr_args
            True,  # use_paddleocr
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Process image and get result (run in a thread pool)
        try:
            dino_labled_img, label_coordinates, parsed_content_list = (
                await loop.run_in_executor(
                    None,
                    get_som_labeled_img,
                    image_save_path,
                    yolo_model,
                    box_threshold,  # BOX_TRESHOLD
                    True,  # output_coord_in_ratio
                    ocr_bbox,  # ocr_bbox
                    draw_bbox_config,  # draw_bbox_config
                    caption_model_processor,  # caption_model_processor
                    text,  # ocr_text
                    iou_threshold,  # iou_threshold
                )
            )
        except Exception as e:
            logger.error(f"Error during labeling and captioning: {e}")
            raise
    finally:
        # Always clean up the scratch file, even when processing failed.
        try:
            os.remove(image_save_path)
        except OSError as cleanup_error:
            logger.warning(
                "Could not remove temporary image %s: %s",
                image_save_path,
                cleanup_error,
            )

    logger.info("Finished processing image with YOLO and captioning.")

    # The labeled image arrives base64-encoded; normalise it to a base64 PNG.
    parsed_content_list_str = "\n".join(parsed_content_list)
    img_str = await loop.run_in_executor(
        None, _reencode_as_png_base64, dino_labled_img
    )

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Handle an image upload and return the labeled result.

    Args:
        image_file: Uploaded image file.
        box_threshold: Box threshold passed through to ``process``.
        iou_threshold: IoU threshold passed through to ``process``.

    Returns:
        The ``ProcessResponse`` produced by ``process``.

    Raises:
        HTTPException: With status 500 and the error text as detail when the
            upload cannot be decoded or processing fails.
    """
    try:
        contents = await image_file.read()
        # Image.open raises on undecodable data, so no separate emptiness
        # check is needed (a PIL image object is always truthy anyway).
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")

        logger.info(f"Processing image: {image_file.filename}")
        logger.info(f"Image size: {image_input.size}")

        response = await process(image_input, box_threshold, iou_threshold)

        # Ensure the response contains an image
        if not response.image:
            raise ValueError("Empty image in response")

        logger.info("Processing complete, returning response.")
        return response
    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e