Madan committed on
Commit
c8f751a
·
1 Parent(s): 354f45f

Add application file

Browse files
Weight/sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
Weight/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879
Weight/sam_vit_l_0b3195.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622
3
+ size 1249524607
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from diffusers import StableDiffusionInpaintPipeline
5
+ from PIL import Image
6
+ from segment_anything import SamPredictor, sam_model_registry
7
+
8
+
9
# All models are placed on this torch device.
device = "cpu"

# Segment Anything (SAM): build the ViT-H variant from the local checkpoint
# and wrap it in a predictor that is shared by every mask request.
model_type = "vit_h"
sam_checkpoint = "Weight/sam_vit_h_4b8939.pth"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# Stable Diffusion 2 inpainting pipeline, loaded in float32 and moved to
# the same device as SAM.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32,
).to(device)

# Click positions collected so far; appended to by the image select callback.
selected_pixels = []
23
+
24
with gr.Blocks() as demo:
    # Image panels: source image, SAM mask preview and inpainted result.
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        output_img = gr.Image(label="Output")
    with gr.Blocks():
        prompt_text = gr.Textbox(lines=1, label="Prompt")
    with gr.Blocks():
        submit = gr.Button("Submit")

    def generate_mask(image, evt: gr.SelectData):
        """Segment the object under the clicked points and return its mask.

        Every click on the input image is appended to the module-level
        ``selected_pixels`` list, and SAM is prompted with all points
        collected so far, each labelled 1.

        Args:
            image: The input image as delivered by the ``gr.Image`` component.
            evt: Gradio select event; ``evt.index`` is stored as the clicked
                point.

        Returns:
            The predicted mask resized to 512x512 with a trailing channel
            axis added.
        """
        # NOTE(review): selected_pixels is never cleared, so points clicked
        # on a previous image are still sent when a new image is loaded.
        selected_pixels.append(evt.index)

        predictor.set_image(image)
        input_points = np.array(selected_pixels)
        # One label per point. The labels must be sized *after* the new
        # click was appended, otherwise SAM receives fewer labels than
        # points (none at all on the first click).
        input_labels = np.ones(input_points.shape[0])

        mask, _, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=False,
        )
        # (n, sz, sz) -> keep the single returned mask.
        mask = Image.fromarray(mask[0, :, :])
        mask = mask.resize((512, 512))  # Resize the mask to (512, 512)
        # NOTE(review): confirm the installed Gradio version accepts a
        # single-channel array for display.
        mask = np.expand_dims(mask, axis=2)
        return mask

    def inpaint(image, mask, prompt):
        """Repaint the masked region of ``image`` according to ``prompt``.

        Args:
            image: Input image array from the "Input" panel.
            mask: Mask array from the "Mask" panel.
            prompt: Text prompt passed to the inpainting pipeline.

        Returns:
            The first image produced by the inpainting pipeline.

        Raises:
            gr.Error: If no image or no mask is available yet.
        """
        if image is None or mask is None:
            # Surface a readable message in the UI instead of an opaque
            # failure inside Image.fromarray.
            raise gr.Error(
                "Upload an image and click on it to create a mask first."
            )

        # Both inputs are resized to 512x512 before being passed to the
        # pipeline.
        image = Image.fromarray(image).resize((512, 512))
        mask = Image.fromarray(mask).resize((512, 512))

        output = pipe(
            prompt=prompt,
            image=image,
            mask_image=mask,
        ).images[0]

        return output

    # Clicking the input image refreshes the mask; Submit runs inpainting.
    input_img.select(generate_mask, [input_img], [mask_img])
    submit.click(
        inpaint,
        inputs=[input_img, mask_img, prompt_text],
        outputs=[output_img],
    )

if __name__ == "__main__":
    demo.launch()