omniapi / main.py
banao-tech's picture
Update main.py
00bda1b verified
raw
history blame
4.95 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
import logging
from PIL import Image
import torch
# Existing imports
from utils import (
check_ocr_box,
get_yolo_model,
get_caption_model_processor,
get_som_labeled_img,
)
from transformers import AutoProcessor, AutoModelForCausalLM
# Logging setup: a module-level logger, with the root logger configured at
# DEBUG so that the most verbose diagnostics are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
# Load the YOLO detector from its local checkpoint.
yolo_model = get_yolo_model(model_path="weights/best.pt")

# Choose CUDA when it is available, otherwise the CPU, and move the detector
# onto the chosen device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
yolo_model = yolo_model.cuda() if device.type == "cuda" else yolo_model.cpu()
# Load caption model and processor.
#
# The processor is loaded on its own, outside the GPU try/except: if it fails
# there is no meaningful fallback, and catching that failure together with the
# GPU load would leave `processor` undefined further down.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)

# Only attempt the GPU load when CUDA is actually present, so CPU-only hosts
# do not read the weights twice.
model = None
if torch.cuda.is_available():
    try:
        model = AutoModelForCausalLM.from_pretrained(
            "weights/icon_caption_florence",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).to("cuda")
    except Exception as e:
        # Best-effort: any GPU-side failure (e.g. out of memory) falls back to CPU.
        logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")

if model is None:
    # Half precision is poorly supported on CPU, so the CPU copy uses float32.
    # NOTE(review): presumably the captioning helper in `utils` feeds float32
    # inputs when the model is not on CUDA -- confirm against utils.
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

# Bundle handed to get_som_labeled_img as `caption_model_processor`.
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")
# FastAPI application object; the /process_image route is registered on it.
app = FastAPI()
class ProcessResponse(BaseModel):
    # Response payload of the /process_image endpoint. All three fields are
    # plain strings.

    # Annotated image, PNG-encoded and then base64-encoded.
    image: str
    # Parsed content entries joined with newlines.
    parsed_content_list: str
    # str() rendering of the label coordinates.
    label_coordinates: str
def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    """Run OCR, box detection and captioning on an image and package the result.

    The image is written to a uniquely named temporary PNG under ``imgs/``
    (the helpers from ``utils`` are called with a file path), processed, and
    the temporary file is removed again before returning.

    Args:
        image_input: Image to analyse.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` holding the annotated image as base64-encoded
        PNG, the parsed content entries joined by newlines, and ``str()`` of
        the label coordinates.

    Raises:
        Exception: Anything raised by the OCR or labeling helpers is
            propagated; labeling failures are logged first.
    """
    # Function-scope import: tempfile is stdlib and only needed here.
    import tempfile

    # A unique file per call avoids concurrent requests/workers overwriting
    # each other's input, which a single fixed path would allow.
    image_dir = "imgs"
    os.makedirs(image_dir, exist_ok=True)
    fd, image_save_path = tempfile.mkstemp(suffix=".png", dir=image_dir)
    os.close(fd)

    try:
        image_input.save(image_save_path)
        logger.info(f"Saved image for processing: {image_save_path}")

        # Scale the overlay drawing parameters with the image width (3200 px
        # wide corresponds to a ratio of 1). The size is read from the
        # in-memory image, so the saved file need not be reopened.
        box_overlay_ratio = image_input.size[0] / 3200
        draw_bbox_config = {
            "text_scale": 0.8 * box_overlay_ratio,
            "text_thickness": max(int(2 * box_overlay_ratio), 1),
            "text_padding": max(int(3 * box_overlay_ratio), 1),
            "thickness": max(int(3 * box_overlay_ratio), 1),
        }

        # OCR pass; the second element of the result is not used here.
        ocr_bbox_rslt, _ = check_ocr_box(
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Detection + captioning + overlay drawing.
        try:
            dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image_save_path,
                yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold,
            )
        except Exception as e:
            logger.error(f"Error during labeling and captioning: {e}")
            raise
    finally:
        # Best-effort cleanup; a leftover scratch file must not mask the
        # real outcome of the request.
        try:
            os.remove(image_save_path)
        except OSError as cleanup_error:
            logger.warning(f"Could not remove temporary image {image_save_path}: {cleanup_error}")

    logger.info("Finished processing image with YOLO and captioning.")

    # The labeled image arrives base64-encoded; decode it and re-encode as
    # PNG so the response format is fixed regardless of the helper's encoding.
    buffered = io.BytesIO()
    with Image.open(io.BytesIO(base64.b64decode(dino_labled_img))) as labeled_image:
        labeled_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # str() each entry so non-string entries cannot make join() raise.
    parsed_content_list_str = "\n".join(str(entry) for entry in parsed_content_list)

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Handle an image upload: decode it, run ``process`` and return the result.

    Args:
        image_file: Uploaded image file.
        box_threshold: Forwarded to ``process``.
        iou_threshold: Forwarded to ``process``.

    Returns:
        The ``ProcessResponse`` produced by ``process``.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image; 500 when processing fails or yields an empty image.
    """
    contents = await image_file.read()
    if not contents:
        # A PIL image object is always truthy, so emptiness has to be checked
        # on the raw bytes.
        raise HTTPException(status_code=400, detail="Image input is empty or invalid.")

    # A bad upload is the client's fault: answer 400 rather than 500.
    try:
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError) as e:
        logger.warning(f"Rejected upload {image_file.filename}: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

    logger.info(f"Processing image: {image_file.filename}")
    logger.info(f"Image size: {image_input.size}")

    # NOTE: process() is synchronous and runs on the event loop here.
    try:
        response = process(image_input, box_threshold, iou_threshold)
        # Ensure the response contains an image
        if not response.image:
            raise ValueError("Empty image in response")
    except Exception as e:
        # logger.exception records the traceback through the logging setup.
        logger.exception(f"Error processing image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    logger.info("Processing complete, returning response.")
    return response