AppTry / app.py
Wh1plashR's picture
fine-tune model
209742c verified
raw
history blame
2.84 kB
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Model initialization
def setup_model():
    """Load the base causal LM, attach the LoRA adapter, and load the tokenizer.

    Both the base checkpoint and the adapter are loaded with
    ``device_map="auto"``. The adapter-wrapped model is put in eval mode and
    then handed to ``torch.compile`` on a best-effort basis: if that call
    fails, the uncompiled model is returned and the reason is logged instead
    of being silently discarded.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    # Local import: the file's import header does not bring in ``logging``.
    import logging

    logger = logging.getLogger(__name__)

    # Load base model.
    # NOTE(review): no quantization or dtype argument is passed here, so this
    # call does not itself select 4-bit loading.
    # NOTE(review): the repo id below does not look like the usual
    # "Qwen/Qwen2.5-0.5B" naming -- confirm it is the adapter's base model.
    base = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-2.5B-0.5", device_map="auto", trust_remote_code=True
    )
    # Load LoRA adapters on top of the base model
    model = PeftModel.from_pretrained(
        base,
        "Wh1plashR/qwen-energy-lora",
        device_map="auto",
        trust_remote_code=True
    )
    # Load tokenizer from the adapter repository
    tokenizer = AutoTokenizer.from_pretrained("Wh1plashR/qwen-energy-lora", use_fast=True)
    # Set to eval and optionally compile (requires PyTorch 2+)
    model.eval()
    try:
        model = torch.compile(model)
    except Exception as exc:
        # Compilation is an optional speed-up; keep serving with the eager
        # model, but record why compilation was skipped.
        logger.warning(
            "torch.compile failed; continuing with the uncompiled model: %s", exc
        )
    return tokenizer, model
# Initialize: load the tokenizer and model once, at import time, so every
# request handled by generate_recommendation reuses the same objects.
tokenizer, model = setup_model()
# Prompt prefix: instruction text placed in front of the user's appliance
# details when the prompt is built. It is sent to the model verbatim, so any
# edit to this text changes what the model is asked to do.
prompt_prefix = """
You are an energy‑saving expert tasked to help households reduce their monthly electricity bills.
Given the user's appliance usage information (device name, wattage, hours used per day, days used per week):
1. Flag the highest energy consumers.
2. Recommend practical, empathetic, achievable actions.
3. Suggest appliance swaps (e.g. LED, inverter AC) and habit changes.
Give at most 5 suggestions and format with bullet points that is <= 100 tokens.
Don't add anything to the response besides the recommendation.
Here is the user's input:
"""
def _clean_recommendation(text: str) -> str:
    """Reduce decoded model output to the bare recommendation lines."""
    # Keep only what follows the last "Recommendations:" marker, if any.
    rec = text.split("Recommendations:")[-1]
    # Remove any trailing notes
    rec = rec.split("Note:")[0]
    # Strip each line and drop the empty ones
    return "\n".join(line.strip() for line in rec.splitlines() if line.strip())


def generate_recommendation(appliance_info: str) -> str:
    """Generate energy-saving recommendations for the given appliance usage.

    Args:
        appliance_info: Free-form text describing the user's appliances.

    Returns:
        The cleaned recommendation text, one non-empty line per entry. If
        ``appliance_info`` is empty, whitespace-only or ``None``, a short
        message asking for input is returned and the model is not called.
    """
    # Guard: nothing to analyse, so skip tokenization and generation.
    if not appliance_info or not appliance_info.strip():
        return "Please enter your appliance usage details."

    # Build prompt
    prompt = prompt_prefix + appliance_info.strip() + "\n\nRecommendations:"
    # Tokenize and move to model device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(model.device)
    # Generate greedily. No temperature is passed: it has no effect when
    # do_sample=False and only triggers a generation-config warning.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=False,
            use_cache=True,
        )
    # Decode only the newly generated tokens. The returned sequence starts
    # with the prompt tokens, so slicing at the prompt length keeps the prompt
    # (and anything the user typed) out of the text that gets post-processed.
    prompt_len = inputs["input_ids"].shape[-1]
    text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return _clean_recommendation(text)
# Gradio interface
def main():
    """Assemble the Gradio UI around generate_recommendation and launch it."""
    # Input: multi-line box for the appliance details.
    usage_box = gr.Textbox(lines=10, placeholder="Enter appliance usage details...")
    # Output: box that shows the generated tips.
    tips_box = gr.Textbox(label="Energy-Saving Recommendations")

    demo = gr.Interface(
        fn=generate_recommendation,
        inputs=usage_box,
        outputs=tips_box,
        title="Energy-Saving Recommendation Generator",
        description="Provide appliance usage details to receive actionable energy-saving tips.",
    )
    demo.launch()


if __name__ == "__main__":
    main()