omniapi / main.py
banao-tech's picture
Update main.py
7ecde71 verified
raw
history blame
4.92 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
import logging
from PIL import Image
import torch
import asyncio
from utils import (
check_ocr_box,
get_yolo_model,
get_caption_model_processor,
get_som_labeled_img,
)
from transformers import AutoProcessor, AutoModelForCausalLM
# Logging: DEBUG level on the root logger, plus a logger named after this module.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# YOLO detector, loaded from the local checkpoint file.
yolo_model = get_yolo_model(model_path="weights/best.pt")

# Use CUDA when PyTorch reports it as available, otherwise the CPU, and move
# the detector onto the chosen device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo_model = yolo_model.cuda() if device.type == "cuda" else yolo_model.cpu()
# Load the caption processor and the icon-caption model.
# The processor involves no device placement, so it is loaded on its own:
# if it fails, that error is raised here instead of turning into a NameError
# when `caption_model_processor` is built below.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)


def _load_caption_model(model_path: str = "weights/icon_caption_florence"):
    """Load the icon-caption model, preferring the GPU in float16.

    If ``device`` (chosen earlier in this module) is CUDA, the model is loaded
    in float16 and moved there; if that fails, the error is logged and the
    model is loaded on the CPU instead. Without CUDA the GPU attempt is
    skipped entirely, so the weights are read only once.

    On the CPU the weights are loaded as float32: float16 inference on CPU is
    poorly supported in PyTorch (several ops lack Half kernels on CPU,
    especially in older releases).

    Args:
        model_path: Directory or hub id passed to ``from_pretrained``.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance.
    """
    if device.type == "cuda":
        try:
            return AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            ).to(device)
        except Exception as e:
            logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
    else:
        logger.warning("CUDA is not available; loading caption model on CPU.")
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )


model = _load_caption_model()
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")
app = FastAPI()


class ProcessResponse(BaseModel):
    """Response body returned by the ``/process_image`` endpoint."""

    # Annotated image: PNG bytes, base64-encoded as text.
    image: str
    # Parsed elements, one ``str(item)`` per line, joined with "\n".
    parsed_content_list: str
    # ``str()`` of the label-coordinate object produced by the labeller.
    label_coordinates: str


# Module-level asyncio queue, created at import time and shared by all requests.
request_queue = asyncio.Queue()
async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    """
    Run OCR, YOLO detection and icon captioning on a single image.

    ``check_ocr_box`` and ``get_som_labeled_img`` are called with a file path,
    so the image is first written as PNG to a uniquely named file under
    ``imgs/``. That file is deleted when processing finishes, whether it
    succeeded or not. The unique name keeps concurrent requests from reading
    or overwriting each other's input.

    Args:
        image_input: Image to analyse.
        box_threshold: Detection threshold forwarded to ``get_som_labeled_img``
            (as its ``BOX_TRESHOLD`` keyword).
        iou_threshold: IoU threshold forwarded to ``get_som_labeled_img``.

    Returns:
        ProcessResponse containing the annotated image as base64-encoded PNG,
        the parsed content joined by newlines, and ``str()`` of the label
        coordinates.

    Raises:
        Exception: any error from writing the image or from the helpers is
            logged with its traceback and re-raised unchanged.
    """
    import contextlib
    import tempfile

    image_save_path = None
    try:
        image_dir = "imgs"
        os.makedirs(image_dir, exist_ok=True)
        fd, image_save_path = tempfile.mkstemp(prefix="upload_", suffix=".png", dir=image_dir)
        os.close(fd)
        # Write the upload to disk so the path-based helpers below actually
        # see it. Saving is blocking I/O, so keep it off the event loop.
        await asyncio.to_thread(image_input.save, image_save_path, format="PNG")

        # Scale overlay drawing parameters with image width (3200 px -> ratio 1.0).
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            "text_scale": 0.8 * box_overlay_ratio,
            "text_thickness": max(int(2 * box_overlay_ratio), 1),
            "text_padding": max(int(3 * box_overlay_ratio), 1),
            "thickness": max(int(3 * box_overlay_ratio), 1),
        }

        # Model inference is blocking; run each stage in a worker thread.
        ocr_bbox_rslt, _is_goal_filtered = await asyncio.to_thread(
            check_ocr_box,
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt

        labeled_img_b64, label_coordinates, parsed_content_list = await asyncio.to_thread(
            get_som_labeled_img,
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )

        # The helper returns a base64 image; decode and re-encode it so the
        # response always carries PNG data.
        labeled_image = Image.open(io.BytesIO(base64.b64decode(labeled_img_b64)))
        png_buffer = io.BytesIO()
        labeled_image.save(png_buffer, format="PNG")
        img_str = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

        parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)
        return ProcessResponse(
            image=img_str,
            parsed_content_list=parsed_content_list_str,
            label_coordinates=str(label_coordinates),
        )
    except Exception as e:
        logger.exception("Error in process function: %s", e)
        raise
    finally:
        # Best-effort cleanup of the per-request input file.
        if image_save_path is not None:
            with contextlib.suppress(OSError):
                os.remove(image_save_path)
# Held around the whole inference pipeline so uploaded images are processed
# one at a time; concurrent requests wait here for their turn.
_inference_lock = asyncio.Lock()


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """
    Parse an uploaded image and return the annotated result.

    Args:
        image_file: Uploaded image, in any format Pillow can decode.
        box_threshold: Detection threshold forwarded to ``process``.
        iou_threshold: IoU threshold forwarded to ``process``.

    Returns:
        The ``ProcessResponse`` built by ``process``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image;
            500 for any other failure, with the error message as detail.
    """
    try:
        contents = await image_file.read()
        try:
            image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Unidentifiable, truncated or oversized uploads are client errors.
            logger.warning("Rejected uploaded file: %s", e)
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        # Only one request runs the models at a time.
        async with _inference_lock:
            return await process(image_input, box_threshold, iou_threshold)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e