Spaces:
Runtime error
Runtime error
File size: 2,110 Bytes
c8f751a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
# Use the GPU when one is available; otherwise fall back to the CPU (the
# original, hard-coded behaviour).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Segment Anything (ViT-H checkpoint) turns the user's clicks into a mask.
sam_checkpoint = "Weight/sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# Stable Diffusion 2 inpainting pipeline that repaints the masked region.
# Weights stay in float32 so the same code runs on both CPU and GPU.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32,
)
pipe = pipe.to(device)

# Pixel coordinates clicked on the input image so far (presumably [x, y] as
# reported by gr.SelectData.index -- confirm). This is module-level state,
# so it is shared by every browser session using this app.
selected_pixels = []
def reset_points():
    """Forget all previously clicked points.

    Wired to the input image's change event so a newly uploaded (or cleared)
    image starts with a fresh set of prompts.

    Returns:
        None, which clears the stale mask preview.
    """
    selected_pixels.clear()
    return None


def generate_mask(image, evt: gr.SelectData):
    """Segment the clicked object with SAM.

    Each click appends its pixel coordinate to the module-level
    ``selected_pixels`` list. All points collected so far are sent to SAM as
    foreground prompts.

    Args:
        image: The input image delivered by ``gr.Image`` (a numpy array).
        evt: Gradio select event; ``evt.index`` is the clicked pixel.

    Returns:
        A 2-D ``uint8`` mask: 255 inside the predicted object, 0 elsewhere.
        It is presumably at the input image's resolution, since SAM's
        ``predict`` maps masks back to the original size (confirm against
        the SamPredictor docs).

    Raises:
        gr.Error: If no input image is loaded.
    """
    if image is None:
        raise gr.Error("Upload an input image first.")
    selected_pixels.append(evt.index)
    # NOTE(review): set_image re-runs the SAM image encoder on every click;
    # caching it per image would make repeated clicks much cheaper.
    predictor.set_image(image)
    input_points = np.array(selected_pixels)
    # One label per point (1 = foreground). Built *after* appending the new
    # click; building it beforehand left one label too few and made SAM fail.
    input_labels = np.ones(len(input_points))
    masks, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=False,
    )
    # masks is (n, H, W) booleans; with multimask_output=False only masks[0]
    # is used. Convert it to a displayable 0/255 grayscale image.
    return np.where(masks[0], 255, 0).astype(np.uint8)


def inpaint(image, mask, prompt):
    """Repaint the masked region of ``image`` according to ``prompt``.

    Image and mask are resized to 512x512 for the diffusion call. The result
    is resized back to the input image's original size so its aspect ratio
    is preserved.

    Args:
        image: Input image array from ``gr.Image``.
        mask: Mask image array from ``gr.Image``. Presumably white (non-zero)
            pixels mark the area to repaint, per the diffusers inpainting
            docs.
        prompt: Text describing what to paint into the masked region.

    Returns:
        The inpainted ``PIL.Image`` at the input image's size.

    Raises:
        gr.Error: If the image or the mask is missing.
    """
    if image is None:
        raise gr.Error("Upload an input image first.")
    if mask is None:
        raise gr.Error("Click on the input image to create a mask first.")
    source = Image.fromarray(image).convert("RGB")
    original_size = source.size
    mask_image = Image.fromarray(mask).convert("L")
    output = pipe(
        prompt=prompt,
        image=source.resize((512, 512)),
        mask_image=mask_image.resize((512, 512)),
    ).images[0]
    return output.resize(original_size)


with gr.Blocks() as demo:
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask")
        output_img = gr.Image(label="Output")
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")
    with gr.Row():
        submit = gr.Button("Submit")

    # Event listeners must be registered inside the Blocks context.
    input_img.change(reset_points, inputs=None, outputs=[mask_img])
    input_img.select(generate_mask, [input_img], [mask_img])
    submit.click(
        inpaint,
        inputs=[input_img, mask_img, prompt_text],
        outputs=[output_img],
    )
def _main():
    """Start the Gradio server for the ``demo`` app."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()
|