liujie31
add space gpy
a1f262a
import gradio as gr
import numpy as np
import random
from PIL import Image
import os
import spaces
from diffusers import StableDiffusion3Pipeline
import torch
from peft import PeftModel
# Prefer the GPU in half precision; fall back to full precision on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

model_repo_id = "frankjoshua/stable-diffusion-3.5-medium"
pipe = StableDiffusion3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)

# Largest signed 32-bit integer; upper bound for the seed slider and for
# randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (in pixels) for the width/height sliders.
MAX_IMAGE_SIZE = 1024
# Hub repositories of the task-specific LoRA adapters, keyed by the name shown
# in the UI dropdown. "None" selects the base model with adapters disabled.
lora_models = {
    "None": None,
    "GenEval": "jieliu/SD3.5M-FlowGRPO-GenEval",
    "Text Rendering": "jieliu/SD3.5M-FlowGRPO-Text",
    "Human Prefer": "jieliu/SD3.5M-FlowGRPO-PickScore",
}

# Resolve prompt files relative to this script, not the current working
# directory, so the app also works when launched from another directory.
_PROMPT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompts")
# One prompt per line; used by sample_prompt() for each task.
lora_prompts = {
    "GenEval": os.path.join(_PROMPT_DIR, "geneval.txt"),
    "Text Rendering": os.path.join(_PROMPT_DIR, "ocr.txt"),
    "Human Prefer": os.path.join(_PROMPT_DIR, "pickscore.txt"),
}


def _attach_lora_adapters(transformer, models):
    """Wrap ``transformer`` in a PeftModel holding every adapter in ``models``.

    Entries whose repo is None are skipped. The first adapter creates the
    PeftModel wrapper; the rest are added with ``load_adapter``. Each adapter
    is registered under its dict key so it can later be selected by name.

    Raises:
        ValueError: if ``models`` contains no adapter repositories.
    """
    adapters = [(name, repo) for name, repo in models.items() if repo is not None]
    if not adapters:
        raise ValueError("No LoRA adapters configured.")
    (first_name, first_repo), *rest = adapters
    peft_model = PeftModel.from_pretrained(transformer, first_repo, adapter_name=first_name)
    for name, repo in rest:
        peft_model.load_adapter(repo, adapter_name=name)
    return peft_model


pipe.transformer = _attach_lora_adapters(pipe.transformer, lora_models)
pipe = pipe.to(device)
# File that stores how many times `infer` has been called. Defaults to
# /data/model_call_counter.txt; set MODEL_CALL_COUNTER_FILE to store it
# elsewhere (e.g. when running without a /data directory).
COUNTER_FILE = os.environ.get("MODEL_CALL_COUNTER_FILE", "/data/model_call_counter.txt")
def get_call_count(path=None):
    """Return the persisted model-call count, or 0 if it cannot be read.

    Args:
        path: Counter file to read; defaults to ``COUNTER_FILE``.

    Returns:
        The integer stored in the file, or 0 when the file is missing,
        unreadable, empty, or does not contain an integer.
    """
    if path is None:
        path = COUNTER_FILE
    try:
        with open(path, "r", encoding="utf-8") as f:
            return int(f.read().strip())
    except (OSError, ValueError):
        # Missing/unreadable (OSError) or corrupt/empty (ValueError) file:
        # treat as "no calls recorded yet" instead of failing the caller.
        return 0
def update_call_count():
    """Increment the persisted model-call counter and return the new value.

    Persistence is best-effort. If the counter file cannot be written (for
    example, its directory does not exist), a warning is logged and the
    incremented value is still returned, so image generation is not aborted
    over a statistic.
    """
    import logging

    count = get_call_count() + 1
    tmp_path = f"{COUNTER_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(str(count))
        # Write-then-rename: a reader never observes a truncated (empty) file,
        # which get_call_count would read as 0 and thereby reset the counter.
        os.replace(tmp_path, COUNTER_FILE)
    except OSError as err:
        logging.getLogger(__name__).warning(
            "Could not persist call count to %s: %s", COUNTER_FILE, err
        )
    return count
def sample_prompt(lora_model):
    """Return a random prompt from the prompt file of ``lora_model``.

    Blank lines in the file are ignored. For "GenEval" the i-th remaining line
    (0-based) is drawn with weight proportional to 1 / (i + 1), favouring
    earlier lines; for the other tasks every line is equally likely.

    Returns:
        The stripped prompt; "" when ``lora_model`` is "None" or unknown;
        "No prompts found in file." when the file has no non-blank lines;
        "Prompt file not found." when the file does not exist.
    """
    if lora_model == "None" or lora_model not in lora_models:
        return ""
    try:
        with open(lora_prompts[lora_model], "r", encoding="utf-8") as file:
            prompts = [line.strip() for line in file if line.strip()]
    except FileNotFoundError:
        return "Prompt file not found."
    # Guard every task: random.choice([]) would raise IndexError.
    if not prompts:
        return "No prompts found in file."
    if lora_model == "GenEval":
        # random.choices accepts relative weights; no normalisation needed.
        weights = [1 / (i + 1) for i in range(len(prompts))]
        return random.choices(prompts, weights=weights, k=1)[0]
    return random.choice(prompts)
def create_grid_image(images):
    """Tile the first four images into one 2x2 RGB mosaic.

    Every cell has the size of ``images[0]``. Images are placed row-major:
    top-left, top-right, bottom-left, bottom-right.
    """
    cell_w, cell_h = images[0].size
    mosaic = Image.new('RGB', (cell_w * 2, cell_h * 2))
    for index in range(4):
        row, col = divmod(index, 2)
        mosaic.paste(images[index], (col * cell_w, row * cell_h))
    return mosaic
@spaces.GPU
def infer(
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lora_model,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate four images with the selected LoRA and return them as a 2x2 grid.

    Args:
        prompt: User prompt. When empty or whitespace-only, a prompt is drawn
            with ``sample_prompt(lora_model)`` separately for each image.
        seed: Starting seed; image i uses ``seed + i`` unless randomizing.
        randomize_seed: If true, draw an independent random seed per image.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps passed to the pipeline.
        lora_model: Adapter name to activate, or "None" to run with all
            adapters disabled.
        progress: Gradio progress tracker (tracks the pipeline's tqdm bars).

    Returns:
        Tuple of (grid image, comma-separated seeds used, call-count message).
    """
    import contextlib

    call_count = update_call_count()

    # Slider values may arrive as floats; torch seeds and pipeline sizes need ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    # Choose the adapter state once for the whole batch instead of per image.
    if lora_model == "None":
        adapter_ctx = pipe.transformer.disable_adapter()
    else:
        pipe.transformer.set_adapter(lora_model)
        adapter_ctx = contextlib.nullcontext()

    has_user_prompt = bool(prompt and prompt.strip())
    images = []
    seeds = []
    with adapter_ctx:
        for i in range(4):
            # Sequential seeds keep runs reproducible when not randomizing.
            current_seed = random.randint(0, MAX_SEED) if randomize_seed else seed + i
            seeds.append(current_seed)
            generator = torch.Generator().manual_seed(current_seed)
            # Only read the prompt file when the user gave no prompt.
            final_prompt = prompt if has_user_prompt else sample_prompt(lora_model)
            image = pipe(
                prompt=final_prompt,
                negative_prompt="",
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]
            images.append(image)

    grid_image = create_grid_image(images)
    return grid_image, ", ".join(map(str, seeds)), f"Model has been called {call_count} times"
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# SD3.5 Medium + Flow-GRPO
Our model is trained separately for different tasks, so it’s best to use the corresponding prompt format for each task.
**User Guide:**
1. Select a LoRA model (choose “None” to use the base model)
2. Click “Sample Prompt” to randomly select from ~1000 task-specific prompts, or write your own
3. Click “Run” to generate images (a 2×2 grid of 4 images will be produced)
**Note:**
- For the *Text Rendering* task, please enclose the text to be displayed in **double quotes (`"`)**, not single quotes (`'`)
""")
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Row():
            lora_model = gr.Dropdown(
                label="LoRA Model",
                choices=list(lora_models.keys()),
                value="GenEval",
            )
            sample_prompt_button = gr.Button("Sample Prompt", scale=0, variant="secondary")

            def update_sampled_prompt(lora_model):
                """Fill the prompt box with a prompt sampled for ``lora_model``."""
                return sample_prompt(lora_model)

            sample_prompt_button.click(
                fn=update_sampled_prompt,
                inputs=[lora_model],
                outputs=[prompt],
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Results (2x2 Grid)", show_label=True)
        seed_display = gr.Textbox(label="Seeds Used", show_label=True)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Starting Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seeds", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=4.5,  # Replace with defaults that work for your model
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,  # Replace with defaults that work for your model
                )
        call_count_display = gr.Textbox(
            label="Model Call Count",
            # A callable value is re-evaluated by Gradio on every page load, so
            # the count is current rather than frozen at app start-up.
            value=lambda: f"Model has been called {get_call_count()} times",
            interactive=False,
        )

    # Run generation from the button or by pressing Enter in the prompt box.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            lora_model,
        ],
        outputs=[result, seed_display, call_count_display],
    )

if __name__ == "__main__":
    demo.launch()