Nastarang committed on
Commit d12b093 · 1 Parent(s): f79c0a2

Add chatbot file

Files changed (1)
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ from transformers import AutoTokenizer
+ from auto_gptq import AutoGPTQForCausalLM
+ import gradio as gr
+
+ checkpoint = "cortecs/Meta-Llama-3-8B-Instruct-GPTQ-8b"
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
+
+ # Load GPTQ model on GPU if available, otherwise fall back to CPU
+ model = AutoGPTQForCausalLM.from_quantized(
+     checkpoint,
+     device="cuda:0" if torch.cuda.is_available() else "cpu",
+     torch_dtype=torch.float16,
+ )
+
+ # Llama 3 ends assistant turns with <|eot_id|>, so treat it as a stop token too
+ terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+
+ # Build the prompt and generate a response (history is ignored)
+ def predict(message, history):
+     # Llama 3 uses its own chat template, not the Llama 2 style "[INST]" tags
+     messages = [{"role": "user", "content": message.strip()}]
+     input_ids = tokenizer.apply_chat_template(
+         messages, add_generation_prompt=True, return_tensors="pt"
+     ).to(model.device)
+
+     outputs = model.generate(
+         input_ids=input_ids,
+         do_sample=True,
+         temperature=0.6,
+         top_p=0.9,
+         max_new_tokens=256,
+         eos_token_id=terminators,
+     )
+
+     # Decode only the newly generated tokens, dropping the echoed prompt
+     response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
+     return response.strip()
+
+ # Launch the Gradio chatbot
+ gr.ChatInterface(predict, title="LLaMA 3 Chatbot").launch(debug=True)
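
For reference, the running app can be smoke-tested from Python with Gradio's client library. A minimal sketch, assuming the app is reachable at Gradio's default local URL (launch() prints the actual one); gr.ChatInterface exposes its handler at the /chat API endpoint, and the example question is arbitrary:

from gradio_client import Client  # pip install gradio_client

client = Client("http://127.0.0.1:7860/")  # assumed default local URL from launch()
reply = client.predict("What does GPTQ quantization do?", api_name="/chat")
print(reply)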