requirement file
app.py CHANGED
@@ -3,17 +3,17 @@ from transformers import AutoTokenizer
 from auto_gptq import AutoGPTQForCausalLM
 import gradio as gr
 
-checkpoint = "
+checkpoint = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
 
 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
 
-# Load GPTQ model
-
+# Load GPTQ model correctly
 model = AutoGPTQForCausalLM.from_quantized(
     checkpoint,
     device="cuda:0" if torch.cuda.is_available() else "cpu",
-    torch_dtype=torch.
+    torch_dtype=torch.float32,
+    trust_remote_code=True
 )
 
 # Function to format prompt + generate response
@@ -24,7 +24,7 @@ def predict(message, history):
     outputs = model.generate(
         **inputs,
         do_sample=True,
-        temperature=0.
+        temperature=0.7,
         top_p=0.9,
         max_new_tokens=256,
         eos_token_id=tokenizer.eos_token_id
@@ -35,4 +35,9 @@ def predict(message, history):
     return response
 
 # Launch Gradio chatbot
-gr.ChatInterface(predict
+gr.ChatInterface(predict).launch(debug=True)
+
+
+demo.launch()
+
+
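For reference, a minimal sketch of what the full app.py plausibly looks like after this commit. The diff omits lines 1-2, 20-23, and 31-34, so the torch import, the prompt formatting, and the response decoding below are assumptions (a standard Mistral-Instruct [INST] wrapper and decoding of only the newly generated tokens). The trailing demo.launch() added by the commit is left out of the sketch: no demo object is defined anywhere in the shown code, so that call would raise a NameError after the ChatInterface launch returns.

# Hypothetical reconstruction of app.py after this commit; unshown lines are assumptions.
import torch
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
import gradio as gr

checkpoint = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)

# Load GPTQ model
model = AutoGPTQForCausalLM.from_quantized(
    checkpoint,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True
)

# Function to format prompt + generate response
def predict(message, history):
    # Assumed prompt format: the standard Mistral-Instruct [INST] wrapper
    prompt = f"<s>[INST] {message} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=256,
        eos_token_id=tokenizer.eos_token_id
    )
    # Assumed decoding: drop the prompt tokens and return only the newly generated text
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )
    return response

# Launch Gradio chatbot
gr.ChatInterface(predict).launch(debug=True)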