banao-tech committed on
Commit
712f1db
·
verified ·
1 Parent(s): d03c47c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +7 -40
main.py CHANGED
@@ -1,5 +1,4 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
4
  import base64
5
  import io
@@ -17,31 +16,20 @@ from utils import (
17
  get_som_labeled_img,
18
  )
19
 
20
- # Import YOLO from ultralytics and transformers for captioning
21
  from ultralytics import YOLO
22
- from transformers import AutoProcessor, AutoModelForCausalLM
23
-
24
- # ---------------------------------------------------------------------------
25
- # Load the YOLO model
26
- # ---------------------------------------------------------------------------
27
- try:
28
- yolo_model = torch.load("weights/icon_detect/best.pt", map_location="cuda", weights_only=False)["model"]
29
- yolo_model = yolo_model.to("cuda")
30
- except Exception as e:
31
- print("Error loading YOLO model on CUDA:", e)
32
- yolo_model = torch.load("weights/icon_detect/best.pt", map_location="cpu", weights_only=False)["model"]
33
 
 
 
34
  print(f"YOLO model type: {type(yolo_model)}")
35
 
36
- # ---------------------------------------------------------------------------
37
  # Load the captioning model (Florence-2)
38
- # ---------------------------------------------------------------------------
 
39
  device = "cuda" if torch.cuda.is_available() else "cpu"
40
  dtype = torch.float16 if device == "cuda" else torch.float32
41
 
42
- # Load the processor for the Florence-2 model
43
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
44
-
45
  try:
46
  model = AutoModelForCausalLM.from_pretrained(
47
  "weights/icon_caption_florence",
@@ -50,14 +38,12 @@ try:
50
  ).to(device)
51
  except Exception as e:
52
  print(f"Error loading caption model: {str(e)}")
53
- # Fallback to CPU with float32
54
  model = AutoModelForCausalLM.from_pretrained(
55
  "weights/icon_caption_florence",
56
  torch_dtype=torch.float32,
57
  trust_remote_code=True
58
  ).to("cpu")
59
 
60
- # Force configuration for DaViT vision tower if missing
61
  if not hasattr(model.config, 'vision_config'):
62
  model.config.vision_config = {}
63
  if 'model_type' not in model.config.vision_config:
@@ -66,9 +52,6 @@ if 'model_type' not in model.config.vision_config:
66
  caption_model_processor = {"processor": processor, "model": model}
67
  print("Finish loading caption model!")
68
 
69
- # ---------------------------------------------------------------------------
70
- # Create FastAPI application and response model
71
- # ---------------------------------------------------------------------------
72
  app = FastAPI()
73
 
74
  class ProcessResponse(BaseModel):
@@ -76,18 +59,13 @@ class ProcessResponse(BaseModel):
76
  parsed_content_list: str
77
  label_coordinates: str
78
 
79
- # ---------------------------------------------------------------------------
80
- # Main processing function
81
- # ---------------------------------------------------------------------------
82
  def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
83
- # Save the input image temporarily
84
  image_save_path = "imgs/saved_image_demo.png"
85
  os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
86
  image_input.save(image_save_path)
87
 
88
- # Open the saved image for processing
89
  image = Image.open(image_save_path)
90
- box_overlay_ratio = image.size[0] / 3200 # adjust scaling factor as needed
91
  draw_bbox_config = {
92
  "text_scale": 0.8 * box_overlay_ratio,
93
  "text_thickness": max(int(2 * box_overlay_ratio), 1),
@@ -95,7 +73,6 @@ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float
95
  "thickness": max(int(3 * box_overlay_ratio), 1),
96
  }
97
 
98
- # Run OCR to get text and OCR bounding boxes
99
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
100
  image_save_path,
101
  display_img=False,
@@ -106,7 +83,6 @@ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float
106
  )
107
  text, ocr_bbox = ocr_bbox_rslt
108
 
109
- # Run YOLO and semantic processing to get the labeled image and bounding boxes
110
  dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
111
  image_save_path,
112
  yolo_model,
@@ -118,13 +94,10 @@ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float
118
  ocr_text=text,
119
  iou_threshold=iou_threshold,
120
  )
121
-
122
- # Decode the base64-encoded image output from get_som_labeled_img
123
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
124
  print("Finish processing")
125
  parsed_content_list_str = "\n".join(parsed_content_list)
126
 
127
- # Encode final image to base64 string for response
128
  buffered = io.BytesIO()
129
  image.save(buffered, format="PNG")
130
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
@@ -135,9 +108,6 @@ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float
135
  label_coordinates=str(label_coordinates),
136
  )
137
 
138
- # ---------------------------------------------------------------------------
139
- # FastAPI endpoint for image processing
140
- # ---------------------------------------------------------------------------
141
  @app.post("/process_image", response_model=ProcessResponse)
142
  async def process_image(
143
  image_file: UploadFile = File(...),
@@ -148,13 +118,10 @@ async def process_image(
148
  contents = await image_file.read()
149
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
150
 
151
- # Debug logging for file information
152
  print(f"Processing image: {image_file.filename}")
153
  print(f"Image size: {image_input.size}")
154
 
155
  response = process(image_input, box_threshold, iou_threshold)
156
-
157
- # Validate response
158
  if not response.image:
159
  raise ValueError("Empty image in response")
160
 
@@ -162,5 +129,5 @@ async def process_image(
162
 
163
  except Exception as e:
164
  import traceback
165
- traceback.print_exc() # Print full traceback for debugging
166
  raise HTTPException(status_code=500, detail=str(e))
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
2
  from pydantic import BaseModel
3
  import base64
4
  import io
 
16
  get_som_labeled_img,
17
  )
18
 
19
+ # Load the YOLO model using the ultralytics class instead of torch.load
20
  from ultralytics import YOLO
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Use the YOLO constructor to load the model properly
23
+ yolo_model = YOLO("weights/icon_detect/best.pt")
24
  print(f"YOLO model type: {type(yolo_model)}")
25
 
 
26
  # Load the captioning model (Florence-2)
27
+ from transformers import AutoProcessor, AutoModelForCausalLM
28
+
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  dtype = torch.float16 if device == "cuda" else torch.float32
31
 
 
32
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
 
33
  try:
34
  model = AutoModelForCausalLM.from_pretrained(
35
  "weights/icon_caption_florence",
 
38
  ).to(device)
39
  except Exception as e:
40
  print(f"Error loading caption model: {str(e)}")
 
41
  model = AutoModelForCausalLM.from_pretrained(
42
  "weights/icon_caption_florence",
43
  torch_dtype=torch.float32,
44
  trust_remote_code=True
45
  ).to("cpu")
46
 
 
47
  if not hasattr(model.config, 'vision_config'):
48
  model.config.vision_config = {}
49
  if 'model_type' not in model.config.vision_config:
 
52
  caption_model_processor = {"processor": processor, "model": model}
53
  print("Finish loading caption model!")
54
 
 
 
 
55
  app = FastAPI()
56
 
57
  class ProcessResponse(BaseModel):
 
59
  parsed_content_list: str
60
  label_coordinates: str
61
 
 
 
 
62
  def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
 
63
  image_save_path = "imgs/saved_image_demo.png"
64
  os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
65
  image_input.save(image_save_path)
66
 
 
67
  image = Image.open(image_save_path)
68
+ box_overlay_ratio = image.size[0] / 3200
69
  draw_bbox_config = {
70
  "text_scale": 0.8 * box_overlay_ratio,
71
  "text_thickness": max(int(2 * box_overlay_ratio), 1),
 
73
  "thickness": max(int(3 * box_overlay_ratio), 1),
74
  }
75
 
 
76
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
77
  image_save_path,
78
  display_img=False,
 
83
  )
84
  text, ocr_bbox = ocr_bbox_rslt
85
 
 
86
  dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
87
  image_save_path,
88
  yolo_model,
 
94
  ocr_text=text,
95
  iou_threshold=iou_threshold,
96
  )
 
 
97
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
98
  print("Finish processing")
99
  parsed_content_list_str = "\n".join(parsed_content_list)
100
 
 
101
  buffered = io.BytesIO()
102
  image.save(buffered, format="PNG")
103
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
108
  label_coordinates=str(label_coordinates),
109
  )
110
 
 
 
 
111
  @app.post("/process_image", response_model=ProcessResponse)
112
  async def process_image(
113
  image_file: UploadFile = File(...),
 
118
  contents = await image_file.read()
119
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
120
 
 
121
  print(f"Processing image: {image_file.filename}")
122
  print(f"Image size: {image_input.size}")
123
 
124
  response = process(image_input, box_threshold, iou_threshold)
 
 
125
  if not response.image:
126
  raise ValueError("Empty image in response")
127
 
 
129
 
130
  except Exception as e:
131
  import traceback
132
+ traceback.print_exc()
133
  raise HTTPException(status_code=500, detail=str(e))