sergiopaniego (HF Staff) committed
Commit 9c9f88f · verified · Parent(s): 65083f0

Update app.py

Files changed (1): app.py (+27, -7)
app.py CHANGED
@@ -27,20 +27,40 @@ sam_model = SamModel.from_pretrained("facebook/sam-vit-base", device_map="auto",
27
  sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
28
 
29
  @spaces.GPU
30
- def predict_masks_and_scores(model, processor, raw_image, input_points=None, input_boxes=None):
31
  if input_boxes is not None:
32
  input_boxes = [input_boxes]
33
 
34
- inputs = processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
35
  original_sizes = inputs["original_sizes"]
36
  reshaped_sizes = inputs["reshaped_input_sizes"]
37
 
38
- inputs = inputs.to(model.device)
39
 
40
  with torch.no_grad():
41
- outputs = model(**inputs)
42
 
43
- masks = processor.image_processor.post_process_masks(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
45
  )
46
  scores = outputs.iou_scores
@@ -68,8 +88,8 @@ def process_inputs(prompts):
68
  input_points = [input_points] if input_points else None
69
  user_image = prompts['image']
70
 
71
- sam_masks, sam_scores = predict_masks_and_scores(sam_model, sam_processor, user_image, input_boxes=input_boxes, input_points=input_points)
72
- sam_hq_masks, sam_hq_scores = predict_masks_and_scores(sam_hq_model, sam_hq_processor, user_image, input_boxes=input_boxes, input_points=input_points)
73
 
74
  if input_boxes and input_points:
75
  img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')
 
27
  sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
28
 
29
  @spaces.GPU
30
+ def predict_masks_and_scores_sam_hq(raw_image, input_points=None, input_boxes=None):
31
  if input_boxes is not None:
32
  input_boxes = [input_boxes]
33
 
34
+ inputs = sam_hq_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
35
  original_sizes = inputs["original_sizes"]
36
  reshaped_sizes = inputs["reshaped_input_sizes"]
37
 
38
+ inputs = inputs.to(sam_hq_model.device)
39
 
40
  with torch.no_grad():
41
+ outputs = sam_hq_model(**inputs)
42
 
43
+ masks = sam_hq_processor.image_processor.post_process_masks(
44
+ outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
45
+ )
46
+ scores = outputs.iou_scores
47
+ return masks, scores
48
+
49
+ @spaces.GPU
50
+ def predict_masks_and_scores_sam(raw_image, input_points=None, input_boxes=None):
51
+ if input_boxes is not None:
52
+ input_boxes = [input_boxes]
53
+
54
+ inputs = sam_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
55
+ original_sizes = inputs["original_sizes"]
56
+ reshaped_sizes = inputs["reshaped_input_sizes"]
57
+
58
+ inputs = inputs.to(sam_model.device)
59
+
60
+ with torch.no_grad():
61
+ outputs = sam_model(**inputs)
62
+
63
+ masks = sam_processor.image_processor.post_process_masks(
64
  outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
65
  )
66
  scores = outputs.iou_scores
 
88
  input_points = [input_points] if input_points else None
89
  user_image = prompts['image']
90
 
91
+ sam_masks, sam_scores = predict_masks_and_scores_sam(user_image, input_boxes=input_boxes, input_points=input_points)
92
+ sam_hq_masks, sam_hq_scores = predict_masks_and_scores_sam_hq(user_image, input_boxes=input_boxes, input_points=input_points)
93
 
94
  if input_boxes and input_points:
95
  img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')