rinrikatoki's picture
Upload app.py
be1a16c verified
raw
history blame
2.64 kB
import os
import zipfile
from huggingface_hub import login
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM
# 🔹 گرفتن توکن از محیط
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
raise ValueError("❌ HF_TOKEN not found in environment secrets.")
# 🔹 لاگین به HuggingFace
login(hf_token)
# 🔹 مسیر فایل زیپ
LORA_ZIP_PATH = "dorna-diabetes-finetuned-20250514T183411Z-1-001.zip"
EXTRACT_DIR = "lora" # پوشه‌ای که فایل‌ها داخلش اکسترکت می‌شن
# ✅ اکسترکت فایل ZIP در صورت نیاز
if not os.path.exists(EXTRACT_DIR):
with zipfile.ZipFile(LORA_ZIP_PATH, 'r') as zip_ref:
zip_ref.extractall(EXTRACT_DIR)
print("✅ فایل LoRA اکسترکت شد.")
# ✅ پیدا کردن مسیر فولدر واقعی شامل adapter_config.json
# معمولاً zip یک زیرپوشه می‌سازه داخل خودش
for root, dirs, files in os.walk(EXTRACT_DIR):
if "adapter_config.json" in files:
LORA_PATH = root
break
else:
raise FileNotFoundError("❌ adapter_config.json در هیچ زیرپوشه‌ای یافت نشد.")
# ✅ تغییر نام فایل .safetensors به adapter_model.safetensors اگر لازم بود
for filename in os.listdir(LORA_PATH):
if filename.endswith(".safetensors") and filename != "adapter_model.safetensors":
os.rename(
os.path.join(LORA_PATH, filename),
os.path.join(LORA_PATH, "adapter_model.safetensors")
)
print("✅ اسم فایل تغییر کرد.")
break
# 🔹 بارگذاری مدل و توکنایزر
print("🔹 در حال بارگذاری مدل پایه + LoRA...")
model = AutoPeftModelForCausalLM.from_pretrained(
LORA_PATH,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="auto",
token=hf_token,
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
model.base_model.config._name_or_path,
token=hf_token,
trust_remote_code=True
)
print("✅ مدل و توکنایزر با موفقیت بارگذاری شدند.")
while True:
prompt = input("📝 یک دستور وارد کن (exit برای خروج): ")
if prompt.lower() == "exit":
break
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=200)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("🧠 پاسخ:", response)