import torch
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
import gradio as gr
checkpoint = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
# Load the GPTQ-quantized model (GPU if available, otherwise CPU)
model = AutoGPTQForCausalLM.from_quantized(
    checkpoint,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True
)
# Function to format prompt + generate response
def predict(message, history):
    # Wrap the user message in Mistral's [INST] instruction format
    prompt = f"<s>[INST] {message.strip()} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=256,
        eos_token_id=tokenizer.eos_token_id
    )
    # Decode the full sequence and keep only the text after the prompt
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = decoded.split("[/INST]")[-1].strip()
    return response
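# Optional streaming variant (a sketch, not part of the original script): it
# assumes transformers' TextIteratorStreamer is available and that the
# auto_gptq wrapper forwards the `streamer` kwarg to the underlying generate().
# Gradio's ChatInterface treats a generator function as a streaming handler and
# re-renders the partial response on every yield.
from threading import Thread
from transformers import TextIteratorStreamer

def predict_stream(message, history):
    prompt = f"<s>[INST] {message.strip()} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=256,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Run generation in a background thread so the streamer can be consumed here
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
# To stream, pass predict_stream instead of predict to gr.ChatInterface below.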
# Launch the Gradio chatbot
demo = gr.ChatInterface(predict)
demo.launch(debug=True)