# Ghiblistyle-Free: src/streamlit_app.py
# Last updated by UDface11jkj in commit d45f61f ("Update src/streamlit_app.py", verified).
import streamlit as st
from PIL import Image
import torch
import os
import time
import tempfile
from pathlib import Path
from huggingface_hub import snapshot_download
# === Model Wrapper Class ===
class ImageGenerator:
    """Wrapper that records where the image-edit model weights live.

    Real model loading is still a placeholder (see ``load_model``). For now an
    instance stores the weight locations and settings plus the chosen torch
    device. ``to_cuda`` loads ``ae_path`` into ``self.model``.
    """

    def __init__(self, ae_path, dit_path, qwen2vl_model_path, max_length=640):
        """Store the model locations, pick a device and run the load hook.

        Args:
            ae_path: Path to the autoencoder weights file.
            dit_path: Path to the DiT weights file.
            qwen2vl_model_path: Local path or hub id of the Qwen2-VL model.
            max_length: Stored as-is (default 640); presumably the text
                encoder's token limit. TODO confirm once loading is implemented.
        """
        self.ae_path = ae_path
        self.dit_path = dit_path
        self.qwen2vl_model_path = qwen2vl_model_path
        self.max_length = max_length
        self.device = self._select_device()
        # Defined before load_model() so the attribute always exists. Callers
        # can test ``self.model is None`` instead of hitting AttributeError
        # before to_cuda() has run.
        self.model = None
        self.load_model()

    @staticmethod
    def _select_device():
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self):
        """Placeholder hook for loading the model; currently a no-op."""
        # TODO: replace with actual model loading logic.
        pass

    def to_cuda(self):
        """Re-select the device and load ``ae_path`` onto it into ``self.model``."""
        self.device = self._select_device()
        # NOTE(review): torch.load unpickles its input. This app points ae_path
        # at a ``.safetensors`` file, which is not a pickle archive, so this call
        # most likely fails there. Unpickling downloaded files is also unsafe.
        # Prefer a safetensors loader; confirm before relying on this path.
        self.model = torch.load(self.ae_path, map_location=self.device)
        # TODO: add actual model load logic as needed.
# === Inference Function ===
def inference(prompt, image, seed, size_level, model):
    """Run the (placeholder) style edit and report which seed was used.

    Args:
        prompt: Text instruction for the edit (not used yet).
        image: Input image.
        seed: Seed to report; ``-1`` means "derive one from the current time".
        size_level: Target size level (not used yet).
        model: Model wrapper (not used yet).

    Returns:
        Tuple ``(image, seed)``. Until real inference is implemented, the input
        image is returned unchanged.
    """
    if seed == -1:
        # No explicit seed: fall back to the current Unix time in whole seconds.
        seed = int(time.time())
    # Placeholder: nothing is edited yet, so the "result" is the input itself.
    return image, seed
# === Streamlit UI Setup ===
# Emoji in the labels below were previously mojibake (UTF-8 bytes decoded with
# a Windows codepage); they are restored to the intended characters.
st.set_page_config(page_title="Ghibli style", layout="centered")
st.title("🖼️ Ghibli style for Free : AI Image Editing")
st.markdown("Generate Studio Ghibli style illustrations from your image using AI.")
# Fixed edit instruction; the UI offers no way to change it.
prompt = "Turn into an illustration in Studio Ghibli style"
uploaded_image = st.file_uploader("📤 Upload an Image", type=["jpg", "jpeg", "png"])
seed = st.number_input("🎲 Random Seed (-1 for random)", value=-1, step=1)
size_level = st.number_input("📏 Size Level (minimum 512)", value=512, min_value=512, step=32)
generate_button = st.button("🚀 Generate")
@st.cache_resource
def load_model():
    """Download the Step1X-Edit weights and build the model wrapper.

    Decorated with ``st.cache_resource`` so the returned object is reused
    across script reruns instead of being rebuilt (and re-downloaded) each time.

    Returns:
        ImageGenerator pointing at the downloaded VAE and DiT weight files.
    """
    repo = "stepfun-ai/Step1X-Edit"
    ae_file = "vae.safetensors"
    dit_file = "step1x-edit-i1258.safetensors"
    local_dir = Path.home() / "step1x_weights"
    local_dir.mkdir(parents=True, exist_ok=True)
    # Fetch only the two files the wrapper is pointed at, not the whole
    # repository (which may also hold other, large checkpoint variants).
    snapshot_download(
        repo_id=repo,
        local_dir=local_dir,
        allow_patterns=[ae_file, dit_file],
    )
    return ImageGenerator(
        ae_path=local_dir / ae_file,
        dit_path=local_dir / dit_file,
        # Hub model id rather than a local path; presumably resolved by the
        # real loader once implemented. TODO confirm.
        qwen2vl_model_path="Qwen/Qwen2.5-VL-7B-Instruct",
        max_length=640,
    )
image_edit_model = load_model()

# === Handle Generation ===
if generate_button and uploaded_image is None:
    # Give feedback instead of silently ignoring a click without an upload.
    st.warning("⚠️ Please upload an image before generating.")
elif generate_button:
    try:
        input_image = Image.open(uploaded_image).convert("RGB")
    except OSError as e:
        # Pillow reports unreadable/corrupt image data as an OSError subclass.
        st.error(f"❌ Could not read the uploaded image: {e}")
        st.stop()
    # Downscale in place so the image fits within size_level x size_level.
    input_image.thumbnail((size_level, size_level))
    with st.spinner("🔄 Generating edited image..."):
        start = time.perf_counter()
        try:
            result_image, used_seed = inference(
                prompt, input_image, seed, size_level, image_edit_model
            )
        except Exception as e:  # top-level UI boundary: report, don't crash
            st.error(f"❌ Inference failed: {e}")
            st.stop()
        elapsed = time.perf_counter() - start
    st.success(f"✅ Done in {elapsed:.2f} seconds — Seed used: {used_seed}")
    # Pass the PIL image to Streamlit directly. The old code wrote it to a
    # temp file under a hard-coded /tmp with delete=False and never removed it.
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; confirm against the pinned version.
    st.image(result_image, caption="🖼️ Edited Image", use_column_width=True)