File size: 4,312 Bytes
91df19c 262d1c8 041e285 262d1c8 e348d3e 262d1c8 954a751 262d1c8 954a751 262d1c8 954a751 96c33c2 d0afb9d 262d1c8 d0afb9d 262d1c8 724f6b3 262d1c8 954a751 262d1c8 954a751 262d1c8 954a751 262d1c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import os
import json
import subprocess
from threading import Thread
import torch
import spaces
from peft import PeftModel
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
# Install flash-attn at startup. The override variable is merged into the
# current environment: passing a bare dict as `env` would replace the whole
# environment, so the pip subprocess would lose PATH, HOME, etc.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Base model that the LoRA adapter is applied to.
MODEL_ID = "meta-llama/Llama-2-7b-hf"
# Prompt format used by predict(): "Auto", "ChatML" or "Mistral Instruct".
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
# Maximum number of prompt tokens kept before generation. Falls back to 4096
# when the variable is unset instead of failing with an opaque TypeError from
# int(None).
CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH", "4096"))
# Primary hue passed to the Gradio theme.
COLOR = os.environ.get("COLOR")
# Description text shown in the chat interface.
DESCRIPTION = os.environ.get("DESCRIPTION")
# Hub repo id of the LoRA adapter loaded on top of MODEL_ID.
LORA_WEIGHTS = "DSMI/LLaMA-E"
# Hugging Face access token used when downloading from the Hub.
access_token = os.environ.get('HF_TOKEN')
@spaces.GPU(duration=120)
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream a model reply to `message`, yielding the accumulated text.

    Args:
        message: The latest user message.
        history: Iterable of (user, assistant) pairs from earlier turns.
        system_prompt: System prompt placed at the start of the prompt.
        temperature: Sampling temperature; a value of 0 (or less) switches
            to greedy decoding because sampling requires a positive value.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        top_k: Top-k sampling cutoff (coerced to int).
        repetition_penalty: Repetition penalty forwarded to `generate`.
        top_p: Nucleus sampling cutoff.

    Yields:
        The text generated so far, growing with each streamed chunk.

    Raises:
        ValueError: If CHAT_TEMPLATE is not one of the supported names.
    """
    # Format history with a given chat template
    add_special_tokens = True
    if CHAT_TEMPLATE == "Auto":
        stop_tokens = [tokenizer.eos_token_id]
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        for user, assistant in history:
            messages.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
        messages.append({"role": "user", "content": message})
        # The tokenizer call below only accepts text, so render the message
        # list to a string with the tokenizer's own chat template first.
        instruction = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # The rendered template already contains its special tokens; adding
        # them again at tokenization time would duplicate BOS.
        add_special_tokens = False
    elif CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
        for user, assistant in history:
            instruction += '<|im_start|>user\n' + user + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
        instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = '<s>[INST] ' + system_prompt
        for user, assistant in history:
            instruction += user + ' [/INST] ' + assistant + '</s>[INST]'
        instruction += ' ' + message + ' [/INST]'
    else:
        raise ValueError("Incorrect chat template, select 'Auto', 'ChatML' or 'Mistral Instruct'")
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # A single sequence needs no padding, and requesting it fails on
    # tokenizers that define no pad token.
    enc = tokenizer(
        [instruction],
        return_tensors="pt",
        truncation=True,
        add_special_tokens=add_special_tokens,
    )
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Keep only the most recent CONTEXT_LENGTH tokens. The mask must be
    # trimmed together with the ids or their shapes no longer match.
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        # UI sliders may deliver floats; generate() expects an integer here.
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=repetition_penalty,
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
        )
    else:
        # Sampling rejects a non-positive temperature; the failure would
        # happen inside the worker thread and leave the streamer waiting
        # forever, so decode greedily instead.
        generate_kwargs["do_sample"] = False

    # Generation runs in a background thread; the streamer hands decoded
    # text chunks back to this generator as they are produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): this config is built but not passed to any loader below, and
# 4-bit loading is disabled in it; kept so the module-level name still exists.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=False,
    bnb_4bit_compute_dtype=torch.bfloat16
)
# Tokenizer and base model come from the same repo, so both use MODEL_ID and
# the same access token.
tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, token=access_token)
model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map="auto",
    token=access_token,
)
# Wrap the base model with the LoRA adapter weights.
# force_download=True fetches the adapter again on every start.
model = PeftModel.from_pretrained(
    model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
)
# Create Gradio interface
# The additional inputs are passed to predict() positionally after
# (message, history), so their order must match its signature:
# system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p.
gr.ChatInterface(
    predict,
    title="🦙🛍️ LLaMA-E",
    description=DESCRIPTION,
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("You are HelpingAI a emotional AI always answer my question in HelpingAI style", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        # step=1 keeps the token-count and top-k sliders on whole numbers.
        gr.Slider(128, 4096, 1024, step=1, label="Max new tokens"),
        gr.Slider(1, 80, 40, step=1, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR),
).queue().launch()