Parthebhan commited on
Commit
4ad9d2a
·
verified ·
1 Parent(s): 5d7a51d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import spaces
3
  from huggingface_hub import hf_hub_download
4
  import os
 
5
 
6
  # Function to download models from Hugging Face
7
  def download_models(model_id):
@@ -14,37 +15,38 @@ def download_models(model_id):
14
 
15
 
16
  @spaces.GPU
17
- def yolov9_inference(img_path, model_id, image_size, conf_threshold, iou_threshold):
 
18
  """
19
  Load a YOLOv9 model, configure it, perform inference on an image, and optionally adjust
20
  the input size and apply test time augmentation.
21
 
22
- :param model_path: Path to the YOLOv9 model file.
23
  :param conf_threshold: Confidence threshold for NMS.
24
  :param iou_threshold: IoU threshold for NMS.
25
  :param img_path: Path to the image file.
26
- :param size: Optional, input size for inference.
27
- :return: A tuple containing the detections (boxes, scores, categories) and the results object for further actions like displaying.
28
  """
29
- # Import YOLOv9
30
- import yolov9
31
 
32
  # Load the model
33
  model_path = download_models(model_id)
34
- model = yolov9.load(model_path, device="cpu")
35
 
36
  # Set model parameters
37
  model.conf = conf_threshold
38
  model.iou = iou_threshold
39
 
40
  # Perform inference
41
- results = model(img_path)
42
 
43
  # Optionally, show detection bounding boxes on image
44
  output = results.render()
45
 
46
  return output[0]
47
 
 
48
 
49
 
50
 
 
2
  import spaces
3
  from huggingface_hub import hf_hub_download
4
  import os
5
+ import cv2 # Import OpenCV
6
 
7
  # Function to download models from Hugging Face
8
  def download_models(model_id):
 
15
 
16
 
17
  @spaces.GPU
18
+
19
+ def yolov9_inference(img_path, model_id, conf_threshold, iou_threshold):
20
  """
21
  Load a YOLOv9 model, configure it, perform inference on an image, and optionally adjust
22
  the input size and apply test time augmentation.
23
 
24
+ :param model_id: Identifier of the YOLOv9 model.
25
  :param conf_threshold: Confidence threshold for NMS.
26
  :param iou_threshold: IoU threshold for NMS.
27
  :param img_path: Path to the image file.
28
+ :return: Output image with detections.
 
29
  """
30
+ # Load the image from the file path
31
+ image = cv2.imread(img_path) # Use OpenCV to read the image
32
 
33
  # Load the model
34
  model_path = download_models(model_id)
35
+ model = yolov9.load(model_path, device="cpu") # Change device if needed
36
 
37
  # Set model parameters
38
  model.conf = conf_threshold
39
  model.iou = iou_threshold
40
 
41
  # Perform inference
42
+ results = model(image) # Pass the loaded image instead of the file path
43
 
44
  # Optionally, show detection bounding boxes on image
45
  output = results.render()
46
 
47
  return output[0]
48
 
49
+
50
 
51
 
52