# omniapi / main.py
# Commit b89e6d8 "Update main.py" by banao-tech (verified); file size 5.83 kB.
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
import logging
from PIL import Image
import torch
import asyncio
from utils import (
check_ocr_box,
get_yolo_model,
get_caption_model_processor,
get_som_labeled_img,
)
from transformers import AutoProcessor, AutoModelForCausalLM
# Configure root logging at DEBUG verbosity and create this module's logger.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(name=__name__)
# Detection model: load the YOLO weights, then place the model on the GPU when
# CUDA is available and on the CPU otherwise.
yolo_model = get_yolo_model(model_path="weights/best.pt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo_model = yolo_model.cuda() if device.type == "cuda" else yolo_model.cpu()
# Load the caption processor and the fine-tuned Florence-2 icon-caption model.
# The processor is loaded on its own, outside the GPU/CPU fallback, so a
# failure there is reported directly instead of surfacing later as a
# NameError on ``processor``.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)

model = None
if device.type == "cuda":
    try:
        # Half precision on the GPU reduces memory use.
        model = AutoModelForCausalLM.from_pretrained(
            "weights/icon_caption_florence",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).to("cuda")
    except Exception as e:
        logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
else:
    logger.warning("CUDA is not available. Loading caption model on CPU.")

if model is None:
    # CPU fallback: float16 inference is poorly supported on CPU in many
    # PyTorch versions, so the weights are loaded in float32 here.
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")
# ASGI application object; the routes and the startup hook below register on it.
app: FastAPI = FastAPI()
# Unbounded FIFO (maxsize=0) of pending jobs; the background worker takes
# items from it one at a time.
request_queue: asyncio.Queue = asyncio.Queue(maxsize=0)
# Define a response model for the processed image
class ProcessResponse(BaseModel):
    """
    Response body of ``POST /process_image``.

    All fields are plain strings assembled by ``process``.
    """

    image: str  # Base64 encoded image (re-encoded as PNG by ``process``)
    parsed_content_list: str  # parsed content items joined with newlines
    label_coordinates: str  # str() of the coordinates from get_som_labeled_img
# Define the async worker function
async def worker():
    """
    Background worker: drain ``request_queue`` forever, one job at a time.

    Queue items may be coroutines or already-created tasks/futures. A job
    that raises or is cancelled is logged and skipped; it never stops the
    loop. Cancelling the worker task itself still ends the loop; the job it
    was waiting on is not cancelled by that.
    """
    while True:
        job = await request_queue.get()  # Get the next job from the queue
        try:
            # Coroutines are wrapped in a Task; Tasks/Futures pass through unchanged.
            job_future = asyncio.ensure_future(job)
            # asyncio.wait() returns once the job is finished without re-raising
            # its exception or its cancellation. Awaiting the job directly would
            # re-raise a cancelled job's CancelledError here. CancelledError is
            # not an Exception subclass, so it would end this loop and leave
            # the queue undrained.
            await asyncio.wait({job_future})
            if job_future.cancelled():
                logger.warning("Queued task was cancelled before completion.")
            else:
                job_error = job_future.exception()
                if job_error is not None:
                    logger.error(
                        f"Error while processing task: {job_error}",
                        exc_info=job_error,
                    )
        except Exception as e:
            # e.g. an item that is not awaitable
            logger.exception(f"Error while processing task: {e}")
        finally:
            request_queue.task_done()  # Mark the job as done
# Start the worker when the application starts
@app.on_event("startup")
async def startup_event():
    """
    Launch the background queue worker when the application starts.

    The Task is stored on ``app.state`` because the event loop keeps only a
    weak reference to tasks. A task with no other reference can be
    garbage-collected before it finishes, which would silently stop queue
    processing.
    """
    logger.info("Starting background worker...")
    app.state.worker_task = asyncio.create_task(worker())
# Define the process function
async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    """
    Run OCR, YOLO detection and icon captioning on an image.

    ``check_ocr_box`` and ``get_som_labeled_img`` are called with a file
    path, so the image is first written to a PNG under ``imgs/``. Each call
    uses its own uniquely named file, which is deleted when the call
    finishes. As a result, overlapping calls never read each other's image.

    Args:
        image_input: Image to analyse.
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        ProcessResponse holding the labelled image (base64 PNG), the parsed
        content items joined with newlines, and ``str(label_coordinates)``.

    Raises:
        Exception: anything raised while saving or analysing the image is
            logged with its traceback and re-raised.
    """
    import tempfile  # local import: only this function needs it

    image_save_path = None
    try:
        save_dir = "imgs"
        os.makedirs(save_dir, exist_ok=True)
        # A single shared path would let another request overwrite the file
        # while this request's inference threads are still reading it.
        fd, image_save_path = tempfile.mkstemp(
            prefix="saved_image_", suffix=".png", dir=save_dir
        )
        os.close(fd)
        image_input.save(image_save_path)
        logger.debug(f"Image saved to: {image_save_path}")

        # Scale annotation sizes with image width (3200 px wide -> ratio 1.0).
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            "text_scale": 0.8 * box_overlay_ratio,
            "text_thickness": max(int(2 * box_overlay_ratio), 1),
            "text_padding": max(int(3 * box_overlay_ratio), 1),
            "thickness": max(int(3 * box_overlay_ratio), 1),
        }

        # The helpers are synchronous; run them in threads so the event loop
        # stays responsive.
        ocr_bbox_rslt, _is_goal_filtered = await asyncio.to_thread(
            check_ocr_box,
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt
        dino_labled_img, label_coordinates, parsed_content_list = await asyncio.to_thread(
            get_som_labeled_img,
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )

        # The labelled image arrives base64-encoded; decode it and re-encode
        # as PNG so the response always carries a PNG.
        image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)
        return ProcessResponse(
            image=img_str,
            parsed_content_list=parsed_content_list_str,
            label_coordinates=str(label_coordinates),
        )
    except Exception as e:
        logger.exception(f"Error in process function: {e}")
        raise
    finally:
        # Remove the per-call input file; a failed delete is logged, not raised.
        if image_save_path is not None:
            try:
                os.remove(image_save_path)
            except OSError as cleanup_error:
                logger.warning(
                    f"Could not remove temporary image {image_save_path}: {cleanup_error}"
                )
# Define the process_image endpoint
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """
    Parse an uploaded image and return the labelled image and parsed elements.

    The upload is decoded to RGB here. The actual ``process`` call is queued
    so the background ``worker`` runs requests one at a time. An un-started
    coroutine is queued rather than a Task, because a Task begins running
    as soon as it is created, which would bypass the queue's ordering.

    NOTE: requests are only served while the ``worker`` started on
    application startup is running; without it they wait indefinitely.

    Raises:
        HTTPException: status 500 with the error text for any failure while
            reading, decoding or processing the image.
    """
    try:
        # Read and decode the uploaded image
        contents = await image_file.read()
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")

        # The worker fulfils this future; the endpoint awaits it.
        result_future = asyncio.get_running_loop().create_future()

        async def _run_job():
            # The caller is gone (its await was cancelled): skip the work.
            if result_future.cancelled():
                return
            try:
                response = await process(image_input, box_threshold, iou_threshold)
            except asyncio.CancelledError:
                # The worker itself was cancelled: release the waiting caller.
                result_future.cancel()
                raise
            except Exception as exc:
                if not result_future.done():
                    result_future.set_exception(exc)
                raise  # let the worker log it as well
            if not result_future.done():
                result_future.set_result(response)

        # Add the job to the queue
        await request_queue.put(_run_job())
        logger.info(f"Task added to queue. Current queue size: {request_queue.qsize()}")

        # Wait for the worker to finish this job
        return await result_future
    except Exception as e:
        logger.exception(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e