# omniapi/main.py
# Revision ab332bc ("Update main.py") by banao-tech, 5.37 kB.
import asyncio  # asynchronous operations (run_in_executor)
import base64
import io
import logging
import os
import uuid

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForCausalLM

# Project-local helpers
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)
# Logging: DEBUG level for maximum verbosity while diagnosing the pipeline.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Compute device for every model in this module: CUDA when present, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# YOLO detector, moved onto the selected device.
yolo_model = get_yolo_model(model_path="weights/best.pt")
yolo_model = yolo_model.cuda() if device.type == "cuda" else yolo_model.cpu()
# Load caption model and processor.
# The processor is loaded outside any try-block: if it fails, the error should
# surface immediately rather than leave `processor` undefined for later code.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)


def _load_caption_model(target_device: torch.device):
    """Load the fine-tuned Florence-2 icon-caption model onto *target_device*.

    Half precision is used only on CUDA. CPU support for float16 kernels is
    incomplete in many torch builds, so the CPU path loads float32 weights.
    """
    dtype = torch.float16 if target_device.type == "cuda" else torch.float32
    loaded = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    return loaded.to(target_device)


if device.type == "cuda":
    try:
        model = _load_caption_model(device)
    except Exception as e:
        # Typically an out-of-memory or driver problem; the CPU path keeps the service usable.
        logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
        model = _load_caption_model(torch.device("cpu"))
else:
    logger.info("CUDA not available; loading caption model on CPU.")
    model = _load_caption_model(device)

caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")
# ASGI application that exposes the /process_image endpoint.
app = FastAPI()


# Response payload of /process_image.
# Documented with comments rather than a docstring: pydantic copies a class
# docstring into the OpenAPI schema description.
class ProcessResponse(BaseModel):
    image: str  # annotated image, base64-encoded PNG
    parsed_content_list: str  # parsed element entries joined with newlines
    label_coordinates: str  # str() rendering of the label coordinates
async def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    """Run OCR, icon detection and captioning on *image_input*.

    The ``utils`` helpers are called with a file path, so the image is first
    written to a per-request PNG under ``imgs/``. That file is removed when
    processing finishes. Blocking work runs in the default executor so the
    event loop stays responsive.

    Args:
        image_input: Image to analyse.
        box_threshold: Box threshold forwarded to ``get_som_labeled_img``.
        iou_threshold: IoU threshold forwarded to ``get_som_labeled_img``.

    Returns:
        ProcessResponse with the annotated image as a base64 PNG, the parsed
        content entries joined by newlines, and ``str()`` of the label
        coordinates.

    Raises:
        Any exception from the OCR / labelling helpers (logged, then re-raised).
    """
    # Unique filename per request. A shared fixed path let concurrent requests
    # overwrite each other's input while this coroutine awaited the executor.
    image_save_path = os.path.join("imgs", f"saved_image_{uuid.uuid4().hex}.png")
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    loop = asyncio.get_running_loop()

    try:
        await loop.run_in_executor(None, image_input.save, image_save_path)
        logger.info(f"Saved image for processing: {image_save_path}")

        # Annotation styling scales with image width (ratio 1.0 at 3200 px).
        # Saving does not resize, so the in-memory size equals the file's size.
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            "text_scale": 0.8 * box_overlay_ratio,
            "text_thickness": max(int(2 * box_overlay_ratio), 1),
            "text_padding": max(int(3 * box_overlay_ratio), 1),
            "thickness": max(int(3 * box_overlay_ratio), 1),
        }

        # OCR box extraction (blocking; run in the thread pool).
        ocr_bbox_rslt, _is_goal_filtered = await loop.run_in_executor(
            None,
            check_ocr_box,
            image_save_path,
            False,  # display_img
            "xyxy",  # output_bb_format
            None,  # goal_filtering
            {"paragraph": False, "text_threshold": 0.9},  # easyocr_args
            True,  # use_paddleocr
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Detection, labelling and captioning (blocking; run in the thread pool).
        try:
            labeled_img_b64, label_coordinates, parsed_content_list = await loop.run_in_executor(
                None,
                get_som_labeled_img,
                image_save_path,
                yolo_model,
                box_threshold,  # BOX_TRESHOLD
                True,  # output_coord_in_ratio
                ocr_bbox,  # ocr_bbox
                draw_bbox_config,  # draw_bbox_config
                caption_model_processor,  # caption_model_processor
                text,  # ocr_text
                iou_threshold,  # iou_threshold
            )
        except Exception as e:
            logger.error(f"Error during labeling and captioning: {e}")
            raise
    finally:
        # The file is only an input for the helpers above; remove it so the
        # unique per-request names do not accumulate on disk.
        try:
            os.remove(image_save_path)
        except FileNotFoundError:
            pass  # the save itself failed, so there is nothing to clean up
        except OSError as cleanup_err:
            logger.warning(f"Could not remove temporary image {image_save_path}: {cleanup_err}")

    logger.info("Finished processing image with YOLO and captioning.")

    # Decode the helper's base64 image and re-encode it as PNG, so the response
    # is PNG whatever encoding the helper used.
    buffered = io.BytesIO()
    with Image.open(io.BytesIO(base64.b64decode(labeled_img_b64))) as labeled_image:
        await loop.run_in_executor(None, labeled_image.save, buffered, "PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # str() each entry so that non-string items do not make join() raise.
    parsed_content_list_str = "\n".join(str(item) for item in parsed_content_list)

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Parse an uploaded screenshot into labelled UI elements.

    Returns the annotated image (base64 PNG), the parsed element list and the
    label coordinates. Responds with 400 if the upload is empty or is not a
    readable image, and with 500 if processing fails.
    """
    try:
        contents = await image_file.read()
        # The old `if not image_input` check could never fire: a PIL Image is
        # always truthy. Check the raw upload instead.
        if not contents:
            raise HTTPException(status_code=400, detail="Image input is empty or invalid.")
        try:
            image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError as e:
            # Pillow raises UnidentifiedImageError (an OSError) for
            # unrecognised data, and OSError for truncated/corrupt files.
            logger.warning(f"Could not decode uploaded image {image_file.filename}: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        logger.info(f"Processing image: {image_file.filename}")
        logger.info(f"Image size: {image_input.size}")

        response = await process(image_input, box_threshold, iou_threshold)
        # Ensure the response contains an image
        if not response.image:
            raise ValueError("Empty image in response")
        logger.info("Processing complete, returning response.")
        return response
    except HTTPException:
        # Client errors raised above already carry the correct status code.
        raise
    except Exception as e:
        # logger.exception records the traceback through the logging system.
        logger.exception(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e