import random

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

from modeling.dmm_pipeline import StableDiffusionDMMPipeline

# Download the merged DMM checkpoint from the Hugging Face Hub.
ckpt_path = "ckpt"
snapshot_download(repo_id="MCG-NJU/DMM", local_dir=ckpt_path)

# Load the DMM pipeline in half precision and move it onto the GPU.
pipe = StableDiffusionDMMPipeline.from_pretrained(
    ckpt_path,
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.to("cuda")

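# A minimal sketch of calling the pipeline directly, outside the Gradio UI
# (assumes the snapshot above has finished downloading; the prompt string and
# model index are illustrative placeholders, not values from this demo):
#
#   num_models = pipe.unet.get_num_models()   # how many models DMM merges
#   image = pipe(
#       prompt="a photo of a cat",
#       num_inference_steps=25,
#       guidance_scale=7,
#       model_id=0,                            # any index in range(num_models)
#       generator=torch.Generator().manual_seed(1234),
#   ).images[0]
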
@spaces.GPU
def generate(prompt: str,
             negative_prompt: str,
             model_id: int,
             seed: int = 1234,
             height: int = 512,
             width: int = 512,
             generate_all: bool = True):
    """Generate one image per merged model, or a single image for `model_id`."""
    if generate_all:
        # Sample every model merged into DMM with the same seed so the
        # results can be compared side by side in the gallery.
        outputs = []
        for i in range(pipe.unet.get_num_models()):
            output = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=25,
                guidance_scale=7,
                model_id=i,
                generator=torch.Generator().manual_seed(seed),
            ).images[0]
            outputs.append(output)
        return outputs
    else:
        # Sample only the model selected in the dropdown.
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=25,
            guidance_scale=7,
            model_id=int(model_id),
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        return [output]
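
# Rough example of using generate() from a plain Python session rather than
# through the Gradio button (hypothetical values; the prompt matches the demo's
# default textbox, and "sample.png" is just an illustrative filename):
#
#   images = generate(
#       prompt="portrait photo of a girl, long golden hair, flowers, best quality",
#       negative_prompt="worst quality,low quality,normal quality,lowres,watermark,nsfw",
#       model_id=0,
#       seed=1234,
#       generate_all=False,
#   )
#   images[0].save("sample.png")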

# Display names of the models merged into the DMM checkpoint; the dropdown
# index is passed to the pipeline as `model_id`.
candidates = [
    "0. [JuggernautReborn] realistic",
    "1. [MajicmixRealisticV7] realistic, Asia portrait",
    "2. [EpicRealismV5] realistic",
    "3. [RealisticVisionV5] realistic",
    "4. [MajicmixFantasyV3] animation",
    "5. [MinimalismV2] illustration",
    "6. [RealCartoon3dV17] cartoon 3d",
    "7. [AWPaintingV1.4] animation",
]

def main():
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # DMM Demo
            The merged checkpoint is available at https://huggingface.co/MCG-NJU/DMM.
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    model_id = gr.Dropdown(candidates, label="Model Index", type="index")
                    all_check = gr.Checkbox(label="All (ignore the selection above)")
                    prompt = gr.Textbox("portrait photo of a girl, long golden hair, flowers, best quality", label="Prompt")
                    negative_prompt = gr.Textbox("worst quality,low quality,normal quality,lowres,watermark,nsfw", label="Negative Prompt")
                    with gr.Row():
                        seed = gr.Number(0, label="Seed", precision=0, scale=3)
                        update_seed_btn = gr.Button("🎲", scale=1)
                    with gr.Row():
                        height = gr.Number(768, step=8, label="Height (suggest 512~768)")
                        width = gr.Number(512, step=8, label="Width")
                    submit_btn = gr.Button("Submit", variant="primary")
            output = gr.Gallery(label="images")

        # Wire the form to the generation function and the dice button to a
        # fresh random seed.
        submit_btn.click(generate,
                         inputs=[prompt, negative_prompt, model_id, seed, height, width, all_check],
                         outputs=[output])
        update_seed_btn.click(lambda: random.randint(0, 1000000),
                              outputs=[seed])

    demo.launch()


if __name__ == "__main__":
    main()