Futuresony committed
Commit b7f8793 · verified · 1 Parent(s): 873b4c2

Update app.py

Files changed (1)
  1. app.py +41 -25
app.py CHANGED
@@ -1,46 +1,62 @@
+import os
+import torch
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from peft import PeftModel
+from huggingface_hub import login
 
-# Load base + LoRA model
-base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-lora_model = "Futuresony/future_12_10_2024"
+# Authenticate with Hugging Face using secret HF_TOKEN
+hf_token = os.environ.get("HF_TOKEN")
+if not hf_token:
+    raise RuntimeError("Missing HF_TOKEN in secrets. Please add it in your Space settings.")
 
-tokenizer = AutoTokenizer.from_pretrained(base_model)
-base = AutoModelForCausalLM.from_pretrained(base_model)
-model = PeftModel.from_pretrained(base, lora_model)
+login(token=hf_token)
+
+# Load base model and LoRA adapter
+base_model_id = "unsloth/gemma-2-9b"  # Or your base model
+lora_model_id = "Futuresony/future_12_10_2024"  # Your LoRA fine-tuned model
+
+# Load tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+base_model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16, device_map="auto")
+model = PeftModel.from_pretrained(base_model, lora_model_id)
+
+# Ensure model is in evaluation mode
 model.eval()
 
-# Create generation pipeline
-generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
+def generate_response(message, history, system_message, max_tokens, temperature, top_p):
+    prompt = system_message + "\n\n"
+
+    for user_input, bot_response in history:
+        prompt += f"User: {user_input}\nAssistant: {bot_response}\n"
 
-# Define the chat function
-def respond(message, history, system_message, max_tokens, temperature, top_p):
-    prompt = system_message + "\n"
-    for user, bot in history:
-        prompt += f"User: {user}\nAssistant: {bot}\n"
     prompt += f"User: {message}\nAssistant:"
 
-    response = generator(
-        prompt,
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    outputs = model.generate(
+        **inputs,
         max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
         do_sample=True,
-        return_full_text=False,
-    )[0]["generated_text"]
+        pad_token_id=tokenizer.eos_token_id
+    )
 
-    yield response.strip()
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    final_response = response.split("Assistant:")[-1].strip()
+    return final_response
 
-# Set up Gradio UI
+# Gradio ChatInterface
 demo = gr.ChatInterface(
-    respond,
+    fn=generate_response,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+        gr.Textbox(value="You are a helpful assistant.", label="System Message"),
+        gr.Slider(50, 1024, value=256, step=1, label="Max Tokens"),
+        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
     ],
+    title="LoRA AI Chat Assistant",
+    description="Chat with your fine-tuned model using LoRA adapter."
 )
 
 if __name__ == "__main__":
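
A note on the new decoding path: the old pipeline(..., return_full_text=False) call stripped the prompt automatically, but model.generate plus tokenizer.decode returns prompt and completion together, which is why the update recovers the reply with response.split("Assistant:")[-1]. Below is a minimal standalone sketch of that prompt-assembly and extraction logic with no model loaded; build_prompt and the sample strings are hypothetical stand-ins for illustration, not part of app.py.

# Sketch of the prompt format built inside generate_response and the
# "Assistant:"-split used to recover only the newest reply.
# build_prompt is a hypothetical helper, not part of app.py.
def build_prompt(message, history, system_message):
    prompt = system_message + "\n\n"
    for user_input, bot_response in history:
        prompt += f"User: {user_input}\nAssistant: {bot_response}\n"
    prompt += f"User: {message}\nAssistant:"
    return prompt

history = [("Hi", "Hello! How can I help?")]
prompt = build_prompt("What is LoRA?", history, "You are a helpful assistant.")

# tokenizer.decode would return the prompt followed by the completion;
# simulate that here with a hard-coded completion string.
decoded = prompt + " LoRA adds small trainable low-rank matrices to a frozen base model."
reply = decoded.split("Assistant:")[-1].strip()
print(reply)

Since every history turn also contains an "Assistant:" marker, taking the last split segment works only as long as the model's own reply never emits the literal string "Assistant:"; a stop sequence, or slicing the decoded text by the prompt's token length, would be a more robust cutoff.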