# HiDream-I1 LoRA training script (HF Spaces "Spaces: Paused" page banner removed from this copy).
import os

import torch
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model

# --- CONFIG -------------------------------------------------------------------
DATA_DIR = os.getenv("DATA_DIR", "./data")                # training images live here
MODEL_CACHE = os.getenv("MODEL_DIR", "./hidream-model")   # local snapshot target
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./lora-trained")    # where the adapter is saved
REPO_ID = "HiDream-ai/HiDream-I1-Dev"

# Fall back to CPU (and fp32, since fp16 ops are poorly supported on CPU) so the
# script does not crash outright on a GPU-less host.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# --- STEP 1: ensure a complete snapshot with configs --------------------------
print(f"Downloading full model snapshot to {MODEL_CACHE}")
MODEL_ROOT = snapshot_download(
    repo_id=REPO_ID,
    local_dir=MODEL_CACHE,
    local_dir_use_symlinks=False,  # force a real copy so config.json ends up there
)

# --- STEP 2: load scheduler ---------------------------------------------------
print("Loading scheduler")
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    MODEL_ROOT,
    subfolder="scheduler",
)

# --- STEP 3: load VAE ---------------------------------------------------------
print("Loading VAE")
vae = AutoencoderKL.from_pretrained(
    MODEL_ROOT,
    subfolder="vae",
    torch_dtype=DTYPE,
).to(DEVICE)

# --- STEP 4: load text encoder + tokenizer ------------------------------------
print("Loading text encoder + tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    MODEL_ROOT,
    subfolder="text_encoder",
    torch_dtype=DTYPE,
).to(DEVICE)
tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_ROOT,
    subfolder="tokenizer",
)

# --- STEP 5: load U-Net -------------------------------------------------------
# NOTE(review): HiDream-I1 is published as a transformer-based pipeline; loading
# a "unet" subfolder assumes the snapshot actually ships one — verify against
# the repo layout.
print("Loading U-Net")
unet = UNet2DConditionModel.from_pretrained(
    MODEL_ROOT,
    subfolder="unet",
    torch_dtype=DTYPE,
).to(DEVICE)

# --- STEP 6: build the pipeline -----------------------------------------------
# StableDiffusionPipeline's constructor also requires safety_checker and
# feature_extractor; the original call omitted them, which raises a TypeError.
print("Building StableDiffusionPipeline")
pipe = StableDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to(DEVICE)

# --- STEP 7: apply LoRA adapter -----------------------------------------------
# task_type="CAUSAL_LM" is for causal language models, not a diffusion U-Net,
# and without target_modules peft cannot infer which layers of a
# UNet2DConditionModel to wrap. Target the attention projections instead.
print("Applying LoRA adapter")
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.unet = get_peft_model(pipe.unet, lora_config)

# --- STEP 8: training loop (simulated) ----------------------------------------
print(f"Loading dataset from: {DATA_DIR}")
for step in range(100):
    # Here is where you'd load images, run forward/backward, and step the optimizer.
    print(f"Training step {step + 1}/100")

# --- STEP 9: save the fine-tuned LoRA weights ---------------------------------
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Save only the LoRA adapter (the peft-wrapped unet), not the whole pipeline:
# serializing the full pipeline would try to write the peft wrapper as a UNet
# and does not produce the standalone adapter the step promises.
pipe.unet.save_pretrained(OUTPUT_DIR)
print("Training complete. Saved to", OUTPUT_DIR)