Spaces:
Running
Running
import streamlit as st | |
from PIL import Image | |
import torch | |
import os | |
import time | |
import tempfile | |
from pathlib import Path | |
from huggingface_hub import snapshot_download | |
# === Model Wrapper Class ===
class ImageGenerator:
    """Hold the Step1X-Edit weight locations and the device to run on.

    Model loading is still a placeholder: ``load_model`` only initialises
    ``self.model`` to ``None`` so the attribute always exists after
    construction.
    """

    def __init__(self, ae_path, dit_path, qwen2vl_model_path, max_length=640):
        """Store weight locations, pick a device, then call ``load_model``.

        Args:
            ae_path: Location of the autoencoder (VAE) weights file.
            dit_path: Location of the DiT weights file.
            qwen2vl_model_path: Identifier or path of the Qwen2-VL model.
            max_length: Stored for the (future) real loader; unused here.
        """
        self.ae_path = ae_path
        self.dit_path = dit_path
        self.qwen2vl_model_path = qwen2vl_model_path
        self.max_length = max_length
        self.device = self._select_device()
        self.load_model()

    @staticmethod
    def _select_device():
        """Return the CUDA device when available, otherwise the CPU device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self):
        """Load model weights (placeholder).

        Sets ``self.model`` to ``None`` so code that reads the attribute
        before ``to_cuda`` does not hit ``AttributeError``. Replace with the
        actual model loading logic.
        """
        self.model = None

    def to_cuda(self):
        """Re-select the device and load ``self.ae_path`` onto it.

        Despite the name, this falls back to the CPU when CUDA is unavailable.
        """
        self.device = self._select_device()
        # NOTE(review): torch.load unpickles its input, so only use it on
        # trusted files. It also expects a PyTorch checkpoint, not the
        # safetensors format that the app's load_model() passes as ae_path;
        # the safetensors package's loader is the matching reader for those.
        self.model = torch.load(self.ae_path, map_location=self.device)
        # Add actual model load logic as needed
# === Inference Function ===
def inference(prompt, image, seed, size_level, model):
    """Run the image edit (placeholder) and report the seed that was used.

    Args:
        prompt: Edit instruction text (unused by the placeholder).
        image: Input image; the placeholder returns it unchanged.
        seed: Requested seed, or -1 to derive one from the current time.
        size_level: Target size (unused by the placeholder).
        model: Model wrapper (unused by the placeholder).

    Returns:
        tuple: ``(image, seed_used)``. ``seed_used`` equals ``seed`` unless
        ``seed`` was -1, in which case it is the current Unix time truncated
        to whole seconds.
    """
    if seed == -1:
        # Sentinel value: pick a seed from the wall clock (second resolution).
        effective_seed = int(time.time())
    else:
        effective_seed = seed
    # Placeholder - replace with actual inference logic; for now echo the input.
    return image, effective_seed
# === Streamlit UI Setup ===
# set_page_config is kept as the first Streamlit call in the script.
st.set_page_config(page_title="Ghibli style", layout="centered")
# Emoji below repair UTF-8 mojibake (e.g. "πΌοΈ" was 🖼️ decoded as ISO-8859-7).
st.title("🖼️ Ghibli style for Free : AI Image Editing")
st.markdown("Generate Studio Ghibli style illustrations from your image using AI.")

# Fixed edit instruction sent with every request.
prompt = "Turn into an illustration in Studio Ghibli style"

uploaded_image = st.file_uploader("📤 Upload an Image", type=["jpg", "jpeg", "png"])
# -1 is the sentinel that makes inference() choose the seed itself.
seed = st.number_input("🎲 Random Seed (-1 for random)", value=-1, step=1)
# Used as the bounding box passed to Image.thumbnail() before inference.
size_level = st.number_input("📏 Size Level (minimum 512)", value=512, min_value=512, step=32)
generate_button = st.button("🚀 Generate")
def load_model():
    """Fetch the Step1X-Edit weights and wrap them in an ``ImageGenerator``.

    The full ``stepfun-ai/Step1X-Edit`` repository snapshot is downloaded with
    ``snapshot_download`` into ``~/step1x_weights``, creating that directory
    if needed. The VAE and DiT files from it, plus the Qwen2.5-VL model id,
    are handed to ``ImageGenerator``.

    Returns:
        ImageGenerator: The constructed model wrapper.
    """
    weights_dir = Path.home() / "step1x_weights"
    weights_dir.mkdir(exist_ok=True)
    snapshot_download(repo_id="stepfun-ai/Step1X-Edit", local_dir=weights_dir)

    vae_file = weights_dir / "vae.safetensors"
    dit_file = weights_dir / "step1x-edit-i1258.safetensors"
    return ImageGenerator(
        ae_path=vae_file,
        dit_path=dit_file,
        qwen2vl_model_path="Qwen/Qwen2.5-VL-7B-Instruct",
        max_length=640,
    )
# === Model Loading ===
# Streamlit re-executes this whole script on every widget interaction. Without
# caching, load_model() would re-check the Hub snapshot and rebuild the model
# on each rerun; st.cache_resource keeps one shared instance across reruns
# and sessions.
image_edit_model = st.cache_resource(load_model)()

# === Handle Generation ===
if generate_button and uploaded_image is None:
    # Give feedback instead of silently ignoring the click.
    st.warning("Please upload an image first.")
elif generate_button:
    try:
        input_image = Image.open(uploaded_image).convert("RGB")
    except OSError as e:
        # Pillow reports unreadable or corrupt image data with OSError
        # (UnidentifiedImageError is a subclass).
        st.error(f"❌ Could not read the uploaded image: {e}")
        st.stop()

    # thumbnail() resizes in place, keeps the aspect ratio and never enlarges,
    # so both sides end up <= size_level.
    input_image.thumbnail((size_level, size_level))

    with st.spinner("🔄 Generating edited image..."):
        start = time.perf_counter()  # monotonic clock, suited to durations
        try:
            result_image, used_seed = inference(
                prompt, input_image, seed, size_level, image_edit_model
            )
        except Exception as e:  # UI boundary: surface any model failure to the user
            st.error(f"❌ Inference failed: {e}")
            st.stop()
        elapsed = time.perf_counter() - start

    st.success(f"✅ Done in {elapsed:.2f} seconds — Seed used: {used_seed}")
    # st.image accepts a PIL image directly, so nothing is written to disk.
    # NOTE(review): use_column_width is deprecated in newer Streamlit releases;
    # switch to its replacement once the pinned Streamlit version is confirmed.
    st.image(result_image, caption="🖼️ Edited Image", use_column_width=True)