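"""FastAPI service for screenshot parsing.

Exposes a single POST /process_image endpoint that runs OCR and a YOLO-based
icon detector over an uploaded image, captions the detected elements with a
Florence-2 model, and returns the annotated image as base64 together with the
parsed content list and label coordinates.
"""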
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import os
import logging
from PIL import Image
import torch
# Existing imports
from utils import (
    check_ocr_box,
    get_yolo_model,
    get_caption_model_processor,
    get_som_labeled_img,
)
from transformers import AutoProcessor, AutoModelForCausalLM
# Configure logging
logging.basicConfig(level=logging.DEBUG) # Changed to DEBUG for more verbosity
logger = logging.getLogger(__name__)
# Load YOLO model
yolo_model = get_yolo_model(model_path="weights/best.pt")
# Handle device placement
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if str(device) == "cuda":
    yolo_model = yolo_model.cuda()
else:
    yolo_model = yolo_model.cpu()
# Load caption model and processor.
# The processor runs on CPU, so load it outside the try/except; this avoids a
# NameError in the fallback path when only the model load fails.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)
try:
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to("cuda")
except Exception as e:
    logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
    # float16 is unreliable on CPU, so use float32 for the CPU fallback.
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
caption_model_processor = {"processor": processor, "model": model}
logger.info("Finished loading models!!!")
app = FastAPI()
class ProcessResponse(BaseModel):
    image: str  # Base64 encoded image
    parsed_content_list: str
    label_coordinates: str
def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    image_save_path = "imgs/saved_image_demo.png"
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)
    logger.info(f"Saved image for processing: {image_save_path}")

    # Open image and prepare it for further processing
    image = Image.open(image_save_path)
    box_overlay_ratio = image.size[0] / 3200
    draw_bbox_config = {
        "text_scale": 0.8 * box_overlay_ratio,
        "text_thickness": max(int(2 * box_overlay_ratio), 1),
        "text_padding": max(int(3 * box_overlay_ratio), 1),
        "thickness": max(int(3 * box_overlay_ratio), 1),
    }

    # OCR and YOLO box processing
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format="xyxy",
        goal_filtering=None,
        easyocr_args={"paragraph": False, "text_threshold": 0.9},
        use_paddleocr=True,
    )
    text, ocr_bbox = ocr_bbox_rslt

    # Process image and get result
    try:
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )
    except Exception as e:
        logger.error(f"Error during labeling and captioning: {e}")
        raise
    logger.info("Finished processing image with YOLO and captioning.")

    # Decode the labeled image (returned as base64), then re-encode it as a
    # base64 PNG string for the response
    image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
    parsed_content_list_str = "\n".join(parsed_content_list)
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_list_str,
        label_coordinates=str(label_coordinates),
    )
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
image_file: UploadFile = File(...),
box_threshold: float = 0.05,
iou_threshold: float = 0.1,
):
try:
contents = await image_file.read()
image_input = Image.open(io.BytesIO(contents)).convert("RGB")
logger.info(f"Processing image: {image_file.filename}")
logger.info(f"Image size: {image_input.size}")
# Debugging the input image
if not image_input:
raise ValueError("Image input is empty or invalid.")
response = process(image_input, box_threshold, iou_threshold)
# Ensure the response contains an image
if not response.image:
raise ValueError("Empty image in response")
logger.info("Processing complete, returning response.")
return response
except Exception as e:
logger.error(f"Error processing image: {e}")
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
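
# Example client call (a minimal sketch, assuming the server is reachable at
# http://localhost:8000 and that the `requests` package is available; the URL,
# port, and file name are placeholders for your deployment):
#
#   import requests
#
#   with open("screenshot.png", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/process_image",
#           files={"image_file": ("screenshot.png", f, "image/png")},
#           params={"box_threshold": 0.05, "iou_threshold": 0.1},
#       )
#   resp.raise_for_status()
#   payload = resp.json()
#   # payload["image"] is a base64-encoded PNG; parsed_content_list and
#   # label_coordinates come back as strings, as defined by ProcessResponse.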