# LoRa_Streamlit / train.py
# (Hugging Face page header preserved as comments — the raw scrape
#  "ramimu's picture / Update train.py / 35bd3cf verified / raw /
#  history / blame / 3.95 kB" is page chrome, not Python, and broke parsing.)
import os
import torch
from huggingface_hub import snapshot_download
from diffusers import (
StableDiffusionPipeline,
DPMSolverMultistepScheduler,
AutoencoderKL,
UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model
# ─── CONFIG ───────────────────────────────────────────────────────────────────
# Every path can be overridden through the environment; defaults suit a local run.
DATA_DIR = os.environ.get("DATA_DIR", "./data")                # training images
MODEL_CACHE = os.environ.get("MODEL_DIR", "./hidream-model")   # snapshot destination
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./lora-trained")    # where LoRA weights land
REPO_ID = "HiDream-ai/HiDream-I1-Dev"                          # Hub repo to fine-tune
# ─── STEP 1: ENSURE YOU HAVE A COMPLETE SNAPSHOT WITH CONFIGS ─────────────────
print(f"πŸ“₯ Downloading full model snapshot to {MODEL_CACHE}")
_snapshot_kwargs = {
    "repo_id": REPO_ID,
    "local_dir": MODEL_CACHE,
    # Force real copies (not symlinks) so config.json lands in MODEL_CACHE.
    "local_dir_use_symlinks": False,
}
MODEL_ROOT = snapshot_download(**_snapshot_kwargs)
# ─── STEPS 2–6: LOAD COMPONENTS AND BUILD THE PIPELINE ────────────────────────
# Fall back to CPU (and full precision) when no GPU is present: the original
# hard-coded .to("cuda") crashes outright on CPU-only hosts, and fp16 compute
# is poorly supported on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print("πŸ”„ Loading scheduler")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    MODEL_ROOT,
    subfolder="scheduler",
)

print("πŸ”„ Loading VAE")
vae = AutoencoderKL.from_pretrained(
    MODEL_ROOT,
    subfolder="vae",
    torch_dtype=DTYPE,
).to(DEVICE)

print("πŸ”„ Loading text encoder + tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    MODEL_ROOT,
    subfolder="text_encoder",
    torch_dtype=DTYPE,
).to(DEVICE)
tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_ROOT,
    subfolder="tokenizer",
)

print("πŸ”„ Loading U‑Net")
unet = UNet2DConditionModel.from_pretrained(
    MODEL_ROOT,
    subfolder="unet",
    torch_dtype=DTYPE,
).to(DEVICE)

print("🌟 Building StableDiffusionPipeline")
# StableDiffusionPipeline.__init__ requires `safety_checker` and
# `feature_extractor` arguments — omitting them raises TypeError. Passing
# None with requires_safety_checker=False explicitly opts out of the safety
# checker for this training pipeline.
pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to(DEVICE)
# ─── STEP 7: APPLY LORA ADAPTER ───────────────────────────────────────────────
print("🧠 Applying LoRA adapter")
# Fixes two defects in the original config:
#  * `task_type="CAUSAL_LM"` is a transformers language-model task type; it
#    makes PEFT wrap the model as a causal LM, which is wrong for a diffusion
#    U-Net. Diffusion modules take no task_type.
#  * Without explicit `target_modules`, PEFT cannot infer injection points
#    for UNet2DConditionModel and raises — so target the cross/self-attention
#    projection layers, the standard choice for diffusers LoRA training.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    bias="none",
)
pipe.unet = get_peft_model(pipe.unet, lora_config)
# ─── STEP 8: YOUR TRAINING LOOP (SIMULATED) ────────────────────────────────────
print(f"πŸ“‚ Loading dataset from: {DATA_DIR}")
# Placeholder loop: a real implementation would load batches from DATA_DIR,
# run forward/backward passes, and step the optimizer here.
for step in range(1, 101):
    print(f"Training step {step}/100")
# ─── STEP 9: SAVE THE FINE‑TUNED LO‑RA WEIGHTS ───────────────────────────────
os.makedirs(OUTPUT_DIR, exist_ok=True)  # idempotent: safe if dir already exists
pipe.save_pretrained(OUTPUT_DIR)
print("βœ… Training complete. Saved to", OUTPUT_DIR)