from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
from PIL import Image
import torch
# Custom utility functions from the local utils module
from utils import check_ocr_box, get_som_labeled_img
# Load the YOLO detector with the ultralytics YOLO constructor (rather than
# torch.load) so the weights come back wrapped in a proper model object.
from ultralytics import YOLO

yolo_model = YOLO("weights/icon_detect/best.pt")
print(f"YOLO model type: {type(yolo_model)}")
# Load the captioning model (Florence-2)
from transformers import AutoProcessor, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
try:
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)
except Exception as e:
    # Fall back to CPU/float32 if loading on the preferred device fails.
    print(f"Error loading caption model: {e}")
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float32,
        trust_remote_code=True,
    ).to("cpu")
# Ensure the config exposes a vision_config with the DaViT vision-tower type
# that Florence-2 expects; some exported checkpoints omit it.
if not hasattr(model.config, "vision_config"):
    model.config.vision_config = {}
if "model_type" not in model.config.vision_config:
    model.config.vision_config["model_type"] = "davit"
caption_model_processor = {"processor": processor, "model": model}
print("Finish loading caption model!")
app = FastAPI()
class ProcessResponse(BaseModel):
    image: str  # Base64-encoded PNG of the annotated image
    parsed_content_list: str
    label_coordinates: str
def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    image_save_path = "imgs/saved_image_demo.png"
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)
    image = Image.open(image_save_path)

    # Scale annotation text and line sizes relative to a 3200px-wide reference.
    box_overlay_ratio = image.size[0] / 3200
    draw_bbox_config = {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }
    # Run OCR first; its boxes and text feed into the labeling step below.
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format="xyxy",
        goal_filtering=None,
        easyocr_args={"paragraph": False, "text_threshold": 0.9},
        use_paddleocr=True,
    )
    text, ocr_bbox = ocr_bbox_rslt
    # BOX_TRESHOLD keeps its upstream spelling: it must match the keyword
    # expected by get_som_labeled_img in utils.
    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=iou_threshold,
    )
    # get_som_labeled_img returns the annotated image as a base64 string;
    # decode it, then re-encode it as PNG for the response payload.
    image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    print("Finished processing")
    parsed_content_list_str = "\n".join(parsed_content_list)

    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    try:
        contents = await image_file.read()
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
        print(f"Processing image: {image_file.filename}")
        print(f"Image size: {image_input.size}")
        response = process(image_input, box_threshold, iou_threshold)
        if not response.image:
            raise ValueError("Empty image in response")
        return response
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
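
# Local entry point and example client call. This is a sketch, not part of the
# original service: it assumes uvicorn is installed alongside fastapi and that
# the server is reachable on localhost:8000. Because the endpoint takes a
# File(...) upload, box_threshold and iou_threshold become query parameters,
# so a client using the `requests` library could call it like:
#
#     import requests
#     with open("screenshot.png", "rb") as f:
#         resp = requests.post(
#             "http://localhost:8000/process_image",
#             files={"image_file": f},
#             params={"box_threshold": 0.05, "iou_threshold": 0.1},
#         )
#     data = resp.json()  # keys: image, parsed_content_list, label_coordinates
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)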