|
import logging
import os

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
|
|
|
|
|
# Hugging Face Hub id of the text-to-image model to load.
MODEL_NAME = "katuni4ka/tiny-random-flex.2-preview"

# Local directory passed to from_pretrained as the download cache.
CACHE_DIR = "./model_cache"

os.makedirs(CACHE_DIR, exist_ok=True)

# Choose the device once, then pick a dtype that suits it: float16 only on
# CUDA, because many PyTorch CPU kernels either lack float16 support or are
# much slower in it than in float32.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_DTYPE = torch.float16 if _DEVICE == "cuda" else torch.float32

pipe = AutoPipelineForText2Image.from_pretrained(
    MODEL_NAME,
    torch_dtype=_DTYPE,
    cache_dir=CACHE_DIR,
    # NOTE(review): from_pretrained may ignore kwargs that are not pipeline
    # components or __init__ arguments (with a warning). Flux-style
    # pipelines take `max_sequence_length` at call time instead. Confirm
    # this has any effect for this model.
    max_seq_length=512,
).to(_DEVICE)
|
|
|
|
|
# Output-size presets for the UI dropdown: label -> (width, height) in pixels.
# Insertion order is the order in which the choices are listed in the UI.
# NOTE: the "A4" preset is 3:4, close to (but not exactly) A4's 1:sqrt(2).
ASPECT_RATIOS = {
    "Square (512x512)": (512, 512),
    "Landscape (1024x512)": (1024, 512),
    "Portrait (512x1024)": (512, 1024),
    "A4 (768x1024)": (768, 1024),
}
|
|
|
def generate_image(prompt, aspect_ratio):
    """Generate one image for *prompt* at the size preset *aspect_ratio*.

    Uses fixed sampling settings (20 steps, guidance scale 4.5, seed 42),
    so the same prompt and preset give a reproducible result.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        aspect_ratio: Key of ``ASPECT_RATIOS`` selecting ``(width, height)``.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        KeyError: If *aspect_ratio* is not a key of ``ASPECT_RATIOS``.
        gr.Error: If generation fails. Gradio shows the message in the UI,
            rather than trying to render an error string as an image.
    """
    width, height = ASPECT_RATIOS[aspect_ratio]

    try:
        # Seed the generator on the pipeline's own device. A hard-coded
        # "cuda" generator fails on CPU-only hosts, where the pipeline runs
        # on the CPU.
        generator = torch.Generator(device=pipe.device).manual_seed(42)
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=20,
                guidance_scale=4.5,
                generator=generator,
            )
        return result.images[0]
    except Exception as e:
        # UI boundary: keep the full traceback in the server log, and show
        # a readable message to the user.
        logging.getLogger(__name__).exception("Image generation failed")
        raise gr.Error(f"Error: {e}") from e
|
|
|
|
|
# NOTE(review): "huggingface" is not one of recent Gradio's built-in theme
# names. Gradio may try to load it from the Hub and fall back to the default
# theme with a warning. Confirm which theme is intended.
with gr.Blocks(theme="huggingface", analytics_enabled=False) as demo:
    gr.Markdown("""
    # Tiny Random Flex Text-to-Image Generator
    Experimental Flux-based model with critical fixes for tensor shape errors

    🚧 Important: This model requires specific input dimensions and has limited capabilities
    """)

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Try simple prompts like 'a colorful pattern'",
                lines=2,
            )
            # Choices come from ASPECT_RATIOS in its insertion order.
            aspect_ratio = gr.Dropdown(
                label="Aspect Ratio",
                choices=list(ASPECT_RATIOS.keys()),
                value="Square (512x512)",
            )
            generate_btn = gr.Button("🎨 Generate Image", variant="primary")

        # Right column: read-only display of the generated image.
        with gr.Column():
            output_image = gr.Image(label="Generated Image", interactive=False)

    # Clicking the button calls generate_image(prompt, aspect_ratio) and
    # shows its return value in output_image.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, aspect_ratio],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()