"""FastAPI service that detects and captions UI elements in an uploaded image.

The icon-detection checkpoint and the Florence-2 caption model are loaded once,
at import time, preferring CUDA and falling back to CPU. ``POST /process_image``
returns the annotated image (base64 PNG), the parsed element descriptions and
their label coordinates.
"""
import base64
import io
import os
import traceback
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
# NOTE(review): YOLO itself is unused here; the import is kept presumably so the
# pickled detector checkpoint below can resolve ultralytics classes — confirm.
from ultralytics import YOLO

from utils import (
    check_ocr_box,
    get_caption_model_processor,
    get_som_labeled_img,
    get_yolo_model,
)

YOLO_WEIGHTS_PATH = "weights/icon_detect/best.pt"
CAPTION_WEIGHTS_PATH = "weights/icon_caption_florence"
CAPTION_PROCESSOR_NAME = "microsoft/Florence-2-base"
IMAGE_SAVE_PATH = "imgs/saved_image_demo.png"


def _load_yolo_model():
    """Load the icon-detection model onto CUDA, falling back to CPU on failure."""
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    try:
        detector = torch.load(
            YOLO_WEIGHTS_PATH, map_location="cuda", weights_only=False
        )["model"]
        return detector.to("cuda")
    except Exception as err:  # e.g. no CUDA device/driver, CUDA out of memory
        print(f"Loading detector on CUDA failed ({err!r}); falling back to CPU")
        return torch.load(
            YOLO_WEIGHTS_PATH, map_location="cpu", weights_only=False
        )["model"]


def _load_caption_model():
    """Load the caption model: float16 on CUDA, float32 on the CPU fallback."""
    try:
        return AutoModelForCausalLM.from_pretrained(
            CAPTION_WEIGHTS_PATH,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).to("cuda")
    except Exception as err:  # e.g. no CUDA device/driver, CUDA out of memory
        print(f"Loading caption model on CUDA failed ({err!r}); falling back to CPU")
        # Many PyTorch CPU kernels have no float16 implementation, so a half-precision
        # model left on the CPU tends to fail at inference time; use float32 there.
        return AutoModelForCausalLM.from_pretrained(
            CAPTION_WEIGHTS_PATH,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )


yolo_model = _load_yolo_model()
processor = AutoProcessor.from_pretrained(
    CAPTION_PROCESSOR_NAME, trust_remote_code=True
)
model = _load_caption_model()
caption_model_processor = {"processor": processor, "model": model}
print("finish loading model!!!")

app = FastAPI()


class ProcessResponse(BaseModel):
    """Result returned by ``process`` and the ``/process_image`` endpoint."""

    image: str  # Base64-encoded PNG of the annotated image
    parsed_content_list: str  # Parsed element descriptions, one per line
    label_coordinates: str  # str() of the coordinates returned by get_som_labeled_img


def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR plus element detection/captioning on ``image_input``.

    Args:
        image_input: Image to analyse.
        box_threshold: Detection box threshold forwarded to ``get_som_labeled_img``.
        iou_threshold: IoU threshold forwarded to ``get_som_labeled_img``.

    Returns:
        ProcessResponse with the annotated image re-encoded as base64 PNG, the
        parsed contents joined by newlines, and the label coordinates as a string
        (requested with ``output_coord_in_ratio=True``, presumably fractions of the
        image size — confirm against utils).
    """
    # The utils helpers take a file path, so the image is written to disk first.
    # NOTE(review): every call overwrites the same file; concurrent calls (e.g.
    # several worker processes sharing this directory) would race on it.
    image_save_path = IMAGE_SAVE_PATH
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    # Scale annotation styling with image width (a 3200 px wide image -> ratio 1.0).
    # The saved PNG has the same size as image_input, so no need to reopen it.
    box_overlay_ratio = image_input.size[0] / 3200
    draw_bbox_config = {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }

    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format="xyxy",
        goal_filtering=None,
        easyocr_args={"paragraph": False, "text_threshold": 0.9},
        use_paddleocr=True,
    )
    text, ocr_bbox = ocr_bbox_rslt

    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,  # keyword spelling is defined by utils
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=iou_threshold,
    )
    # The annotated image comes back base64-encoded; decode it and re-encode as PNG
    # so the response format does not depend on what utils produced.
    image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
    print("finish processing")

    parsed_content_list_str = "\n".join(parsed_content_list)

    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Decode an uploaded image and run ``process`` on it.

    Raises:
        HTTPException(400): the upload cannot be decoded as an image.
        HTTPException(500): processing failed or produced an empty image.
    """
    contents = await image_file.read()
    try:
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Undecodable/truncated/oversized uploads are client errors, not 500s.
        # (PIL's UnidentifiedImageError is a subclass of OSError.)
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    try:
        print(f"Processing image: {image_file.filename}")
        print(f"Image size: {image_input.size}")

        # NOTE(review): process() is synchronous and called directly here, so it
        # blocks the event loop while it runs; this also serialises requests within
        # one worker, which the shared save path in process() currently relies on.
        response = process(image_input, box_threshold, iou_threshold)

        if not response.image:
            raise ValueError("Empty image in response")
        return response
    except Exception as e:
        traceback.print_exc()  # full traceback in the server logs
        raise HTTPException(status_code=500, detail=str(e)) from e