banao-tech commited on
Commit
9c616dc
·
verified ·
1 Parent(s): 99cd6de

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +51 -72
main.py CHANGED
@@ -1,71 +1,76 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
2
  from pydantic import BaseModel
 
3
  import base64
4
  import io
5
- import os
6
- import logging
7
  from PIL import Image
8
  import torch
 
 
9
 
10
  # Existing imports
 
 
 
 
 
11
  from utils import (
12
  check_ocr_box,
13
  get_yolo_model,
14
  get_caption_model_processor,
15
  get_som_labeled_img,
16
  )
17
- from transformers import AutoProcessor, AutoModelForCausalLM
18
 
19
- # Configure logging
20
- logging.basicConfig(level=logging.DEBUG) # Changed to DEBUG for more verbosity
21
- logger = logging.getLogger(__name__)
22
 
23
- # Load YOLO model
24
- yolo_model = get_yolo_model(model_path="weights/best.pt")
25
 
26
- # Handle device placement
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- if str(device) == "cuda":
29
- yolo_model = yolo_model.cuda()
30
- else:
31
- yolo_model = yolo_model.cpu()
 
 
 
 
 
 
 
32
 
33
- # Load caption model and processor
34
  try:
35
- processor = AutoProcessor.from_pretrained(
36
- "microsoft/Florence-2-base", trust_remote_code=True
37
- )
38
  model = AutoModelForCausalLM.from_pretrained(
39
  "weights/icon_caption_florence",
40
  torch_dtype=torch.float16,
41
  trust_remote_code=True,
42
  ).to("cuda")
43
- except Exception as e:
44
- logger.warning(f"Failed to load caption model on GPU: {e}. Falling back to CPU.")
45
  model = AutoModelForCausalLM.from_pretrained(
46
  "weights/icon_caption_florence",
47
  torch_dtype=torch.float16,
48
  trust_remote_code=True,
49
  )
50
-
51
  caption_model_processor = {"processor": processor, "model": model}
52
- logger.info("Finished loading models!!!")
53
 
54
  app = FastAPI()
55
 
 
56
  class ProcessResponse(BaseModel):
57
  image: str # Base64 encoded image
58
  parsed_content_list: str
59
  label_coordinates: str
60
 
61
- def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
 
 
 
62
  image_save_path = "imgs/saved_image_demo.png"
63
- os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
64
  image_input.save(image_save_path)
65
-
66
- logger.info(f"Saved image for processing: {image_save_path}")
67
-
68
- # Open image and prepare it for further processing
69
  image = Image.open(image_save_path)
70
  box_overlay_ratio = image.size[0] / 3200
71
  draw_bbox_config = {
@@ -75,7 +80,6 @@ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float
75
  "thickness": max(int(3 * box_overlay_ratio), 1),
76
  }
77
 
78
- # OCR and YOLO box processing
79
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
80
  image_save_path,
81
  display_img=False,
@@ -85,40 +89,33 @@ def process(image_input: Image.Image, box_threshold: float, iou_threshold: float
85
  use_paddleocr=True,
86
  )
87
  text, ocr_bbox = ocr_bbox_rslt
88
-
89
- # Process image and get result
90
- try:
91
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
92
- image_save_path,
93
- yolo_model,
94
- BOX_TRESHOLD=box_threshold,
95
- output_coord_in_ratio=True,
96
- ocr_bbox=ocr_bbox,
97
- draw_bbox_config=draw_bbox_config,
98
- caption_model_processor=caption_model_processor,
99
- ocr_text=text,
100
- iou_threshold=iou_threshold,
101
- )
102
- except Exception as e:
103
- logger.error(f"Error during labeling and captioning: {e}")
104
- raise
105
-
106
- logger.info("Finished processing image with YOLO and captioning.")
107
-
108
- # Convert the image to base64 string
109
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
 
110
  parsed_content_list_str = "\n".join(parsed_content_list)
111
 
 
112
  buffered = io.BytesIO()
113
  image.save(buffered, format="PNG")
114
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
115
 
116
  return ProcessResponse(
117
  image=img_str,
118
- parsed_content_list=parsed_content_list_str,
119
  label_coordinates=str(label_coordinates),
120
  )
121
 
 
122
  @app.post("/process_image", response_model=ProcessResponse)
123
  async def process_image(
124
  image_file: UploadFile = File(...),
@@ -128,26 +125,8 @@ async def process_image(
128
  try:
129
  contents = await image_file.read()
130
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
131
-
132
- logger.info(f"Processing image: {image_file.filename}")
133
- logger.info(f"Image size: {image_input.size}")
134
-
135
- # Debugging the input image
136
- if not image_input:
137
- raise ValueError("Image input is empty or invalid.")
138
-
139
- response = process(image_input, box_threshold, iou_threshold)
140
-
141
- # Ensure the response contains an image
142
- if not response.image:
143
- raise ValueError("Empty image in response")
144
-
145
- logger.info("Processing complete, returning response.")
146
- return response
147
-
148
  except Exception as e:
149
- logger.error(f"Error processing image: {e}")
150
- import traceback
151
- traceback.print_exc()
152
- raise HTTPException(status_code=500, detail=str(e))
153
 
 
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
4
+ from typing import Optional
5
  import base64
6
  import io
 
 
7
  from PIL import Image
8
  import torch
9
+ import numpy as np
10
+ import os
11
 
12
  # Existing imports
13
+ import numpy as np
14
+ import torch
15
+ from PIL import Image
16
+ import io
17
+
18
  from utils import (
19
  check_ocr_box,
20
  get_yolo_model,
21
  get_caption_model_processor,
22
  get_som_labeled_img,
23
  )
24
+ import torch
25
 
26
+ # yolo_model = get_yolo_model(model_path='/data/icon_detect/best.pt')
27
+ # caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="/data/icon_caption_florence")
 
28
 
29
+ from ultralytics import YOLO
 
30
 
31
+ # if not os.path.exists("/data/icon_detect"):
32
+ # os.makedirs("/data/icon_detect")
33
+
34
+ try:
35
+ yolo_model = YOLO("weights/icon_detect/best.pt").to("cuda")
36
+ except:
37
+ yolo_model = YOLO("weights/icon_detect/best.pt")
38
+
39
+ from transformers import AutoProcessor, AutoModelForCausalLM
40
+
41
+ processor = AutoProcessor.from_pretrained(
42
+ "microsoft/Florence-2-base", trust_remote_code=True
43
+ )
44
 
 
45
  try:
 
 
 
46
  model = AutoModelForCausalLM.from_pretrained(
47
  "weights/icon_caption_florence",
48
  torch_dtype=torch.float16,
49
  trust_remote_code=True,
50
  ).to("cuda")
51
+ except:
 
52
  model = AutoModelForCausalLM.from_pretrained(
53
  "weights/icon_caption_florence",
54
  torch_dtype=torch.float16,
55
  trust_remote_code=True,
56
  )
 
57
  caption_model_processor = {"processor": processor, "model": model}
58
+ print("finish loading model!!!")
59
 
60
  app = FastAPI()
61
 
62
+
63
  class ProcessResponse(BaseModel):
64
  image: str # Base64 encoded image
65
  parsed_content_list: str
66
  label_coordinates: str
67
 
68
+
69
+ def process(
70
+ image_input: Image.Image, box_threshold: float, iou_threshold: float
71
+ ) -> ProcessResponse:
72
  image_save_path = "imgs/saved_image_demo.png"
 
73
  image_input.save(image_save_path)
 
 
 
 
74
  image = Image.open(image_save_path)
75
  box_overlay_ratio = image.size[0] / 3200
76
  draw_bbox_config = {
 
80
  "thickness": max(int(3 * box_overlay_ratio), 1),
81
  }
82
 
 
83
  ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
84
  image_save_path,
85
  display_img=False,
 
89
  use_paddleocr=True,
90
  )
91
  text, ocr_bbox = ocr_bbox_rslt
92
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
93
+ image_save_path,
94
+ yolo_model,
95
+ BOX_TRESHOLD=box_threshold,
96
+ output_coord_in_ratio=True,
97
+ ocr_bbox=ocr_bbox,
98
+ draw_bbox_config=draw_bbox_config,
99
+ caption_model_processor=caption_model_processor,
100
+ ocr_text=text,
101
+ iou_threshold=iou_threshold,
102
+ )
 
 
 
 
 
 
 
 
 
 
103
  image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
104
+ print("finish processing")
105
  parsed_content_list_str = "\n".join(parsed_content_list)
106
 
107
+ # Encode image to base64
108
  buffered = io.BytesIO()
109
  image.save(buffered, format="PNG")
110
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
111
 
112
  return ProcessResponse(
113
  image=img_str,
114
+ parsed_content_list=str(parsed_content_list_str),
115
  label_coordinates=str(label_coordinates),
116
  )
117
 
118
+
119
  @app.post("/process_image", response_model=ProcessResponse)
120
  async def process_image(
121
  image_file: UploadFile = File(...),
 
125
  try:
126
  contents = await image_file.read()
127
  image_input = Image.open(io.BytesIO(contents)).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  except Exception as e:
129
+ raise HTTPException(status_code=400, detail="Invalid image file")
 
 
 
130
 
131
+ response = process(image_input, box_threshold, iou_threshold)
132
+ return response