snair94 committed on
Commit
717f2ec
·
verified ·
1 Parent(s): 27dcbe2

Create brief2.py

Browse files
Files changed (1) hide show
  1. brief2.py +74 -0
brief2.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
5
+ from matplotlib.colors import to_rgb
6
+ import re
7
+ import cv2
8
+
9
# Load the CLIPSeg processor and segmentation model from one shared checkpoint.
CHECKPOINT_NAME = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(CHECKPOINT_NAME)
model = CLIPSegForImageSegmentation.from_pretrained(CHECKPOINT_NAME)
12
+
13
def parse_color(color_str):
    """Convert a color string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepted inputs:
      * functional notation, ``rgb(r, g, b)`` or ``rgba(r, g, b, a)``; the
        first three numbers are used and any alpha value is ignored;
      * anything ``matplotlib.colors.to_rgb`` accepts (hex such as
        ``'#ff0000'`` or a color name such as ``'red'``).

    Channels are rounded to the nearest integer and clamped to 0-255 so the
    result always fits in an unsigned 8-bit pixel.

    Raises:
        ValueError: If ``color_str`` is not a string or cannot be parsed.
    """
    message = f"Invalid color format: {color_str}. Use hex like '#ff0000', color name like 'red', or rgba format."
    if not isinstance(color_str, str):
        raise ValueError(message)

    text = color_str.strip()
    try:
        if text.lower().startswith(("rgb(", "rgba(")):
            # Keep only the three color channels; a trailing alpha is dropped.
            channels = [float(token) for token in re.findall(r"[\d.]+", text)][:3]
        else:
            # to_rgb yields floats in [0, 1]; scale them to the 0-255 range.
            channels = [255 * component for component in to_rgb(text)]
    except ValueError as err:
        # Raised by float() on a malformed number or by to_rgb on an unknown color.
        raise ValueError(message) from err

    if len(channels) < 3:
        raise ValueError(message)

    # round() instead of int(): truncation can turn e.g. 127.999... into 127.
    return tuple(max(0, min(255, round(value))) for value in channels)
31
+
32
def apply_mask(image: Image.Image, prompt: str, color: str) -> Image.Image:
    """Paint the region CLIPSeg associates with *prompt* in a flat color.

    Args:
        image: Input picture. It is converted to RGB for the model and to
            RGBA for compositing.
        prompt: Text describing what to segment.
        color: Color string in a format accepted by ``parse_color``.

    Returns:
        An RGBA image with the same size as *image*. Pixels whose sigmoid
        score exceeds 0.5 are replaced (not blended) by *color* with alpha
        128; all other pixels keep their original RGBA value.

    Raises:
        gr.Error: If no image or an empty prompt was supplied.
        ValueError: If *color* cannot be parsed.
    """
    # Local import keeps this block self-contained; the "pt" tensors requested
    # below already require PyTorch to be installed.
    import torch

    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a segmentation prompt.")

    # Validate the color before the expensive model call so bad input fails fast.
    color_rgb = parse_color(color)
    overlay_color = np.array([*color_rgb, 128], dtype=np.uint8)  # RGBA, alpha 128

    # The processor is fed an RGB copy so RGBA / grayscale / palette inputs work too.
    inputs = processor(text=prompt, images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():  # inference only: skip building the autograd graph
        outputs = model(**inputs)

    # NOTE(review): some transformers releases appear to return logits without
    # a leading batch axis for a single prompt - confirm. Collapsing everything
    # in front of the last two axes yields a 2-D map for either layout.
    logits = outputs.logits
    preds = logits.reshape(-1, *logits.shape[-2:])[0]

    # Binary mask: 1 where the sigmoid score is above 0.5.
    mask = (preds.sigmoid().cpu().numpy() > 0.5).astype(np.uint8)

    image_np = np.array(image.convert("RGBA"))
    height, width = image_np.shape[:2]

    # cv2.resize takes the target size as (width, height).
    mask_resized = cv2.resize(mask, (width, height))

    # Broadcast the (H, W, 1) mask and the (4,) color against the (H, W, 4)
    # image instead of materialising a full overlay and a stacked mask.
    masked_image = np.where(mask_resized[..., None] == 1, overlay_color, image_np)
    return Image.fromarray(masked_image)
60
+
61
# Gradio Interface: an image, a text prompt and a color picker feed apply_mask.
input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Textbox(label="Segmentation Prompt", placeholder="e.g., helmet, road, sky"),
    gr.ColorPicker(label="Mask Color", value="#ff0000"),
]
output_component = gr.Image(type="pil", label="Segmented Image")

iface = gr.Interface(
    fn=apply_mask,
    inputs=input_components,
    outputs=output_component,
    title="CLIPSeg Image Masking",
    description="Upload an image, input a prompt (e.g., 'person', 'sky'), and pick a mask color.",
)

# Start the web app as soon as the module is executed.
iface.launch()