File size: 4,312 Bytes
91df19c
262d1c8
 
 
041e285
262d1c8
 
e348d3e
262d1c8
 
954a751
 
262d1c8
954a751
262d1c8
 
 
 
 
954a751
96c33c2
d0afb9d
262d1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0afb9d
262d1c8
 
724f6b3
262d1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
954a751
 
 
 
 
262d1c8
 
954a751
 
262d1c8
 
 
 
 
 
 
 
 
 
 
 
 
954a751
262d1c8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import json
import subprocess
from threading import Thread

import torch
import spaces
from peft import PeftModel
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

# Install flash-attn at startup. The skip-CUDA-build flag is layered on top of
# the inherited environment: passing a bare dict as ``env`` replaces the whole
# environment, which drops PATH/HOME/virtualenv variables that pip relies on.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# --- Fixed identifiers -------------------------------------------------------
# Base checkpoint and the LoRA adapter that is applied on top of it.
MODEL_ID = "meta-llama/Llama-2-7b-hf"
LORA_WEIGHTS = "DSMI/LLaMA-E"

# --- Values read from the environment ----------------------------------------
# CHAT_TEMPLATE selects the prompt format used by predict().
CHAT_TEMPLATE = os.getenv("CHAT_TEMPLATE")
# CONTEXT_LENGTH caps the number of prompt tokens kept; int() raises TypeError
# at import time if the variable is not set.
CONTEXT_LENGTH = int(os.getenv("CONTEXT_LENGTH"))
# COLOR is passed to the Gradio theme, DESCRIPTION to the chat interface.
COLOR = os.getenv("COLOR")
DESCRIPTION = os.getenv("DESCRIPTION")
# Hugging Face access token used when downloading from the Hub.
access_token = os.getenv("HF_TOKEN")

@spaces.GPU(duration=120)
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream a chat reply to ``message`` given the earlier ``history``.

    The prompt is assembled according to the module-level ``CHAT_TEMPLATE``
    ("Auto", "ChatML" or "Mistral Instruct"), tokenized, cropped to the last
    ``CONTEXT_LENGTH`` tokens, and passed to ``model.generate`` on a worker
    thread while this generator relays the streamed text.

    Args:
        message: Latest user message.
        history: Iterable of ``(user, assistant)`` pairs from earlier turns.
        system_prompt: System text. Used by the "ChatML" and
            "Mistral Instruct" templates; the "Auto" template does not use it.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding, in which case ``top_k`` and ``top_p`` are not used.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        top_k: Top-k sampling cutoff (coerced to ``int``).
        repetition_penalty: Repetition penalty forwarded to ``generate``.
        top_p: Nucleus sampling threshold.

    Yields:
        The response text accumulated so far, once per streamed chunk.

    Raises:
        Exception: If ``CHAT_TEMPLATE`` is not one of the supported names.
    """
    # Format history with a given chat template
    if CHAT_TEMPLATE == "Auto":
        # NOTE(review): this is a token id while streamed chunks are text, so
        # it never matches in the loop below; generation is presumably
        # expected to stop on EOS by itself - confirm.
        stop_tokens = [tokenizer.eos_token_id]
        conversation = []
        for user, assistant in history:
            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
        conversation.append({"role": "user", "content": message})
        # The tokenizer call below only accepts text, so the message dicts are
        # rendered to a single prompt string with the tokenizer's own template.
        instruction = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    elif CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
        for user, assistant in history:
            instruction += '<|im_start|>user\n' + user + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
        instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = '<s>[INST] ' + system_prompt
        for user, assistant in history:
            instruction += user + ' [/INST] ' + assistant + '</s>[INST]'
        instruction += ' ' + message + ' [/INST]'
    else:
        raise Exception("Incorrect chat template, select 'Auto', 'ChatML' or 'Mistral Instruct'")
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Keep only the most recent CONTEXT_LENGTH tokens. The mask is cropped with
    # the ids so both tensors keep the same shape.
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    # generate() runs on a worker thread, so an argument it rejects would show
    # up as a stalled stream rather than an error here. Integer-valued options
    # are coerced (UI sliders may deliver floats) and a non-positive
    # temperature falls back to greedy decoding.
    sampling = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=sampling,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=repetition_penalty,
    )
    if sampling:
        generate_kwargs.update(temperature=temperature, top_k=int(top_k), top_p=top_p)

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        # Stop relaying as soon as a chunk is exactly one of the stop strings;
        # that chunk is not yielded.
        if new_token in stop_tokens:
            break
        yield "".join(outputs)


# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): this config is built but not passed to from_pretrained below,
# and 4-bit loading is disabled in it anyway - confirm whether it is still needed.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=False,
    bnb_4bit_compute_dtype=torch.bfloat16
)
# Tokenizer and base weights come from the same checkpoint (MODEL_ID) and use
# the same access token, so both downloads authenticate the same way.
tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, token=access_token)
model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map="auto",
    token=access_token,
)

# Wrap the base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(
    model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
)

# Create Gradio interface
# The additional inputs are passed to predict() positionally after
# (message, history), so their order must match predict's signature:
# system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p.
gr.ChatInterface(
    predict,
    title="🦙🛍️ LLaMA-E",
    description=DESCRIPTION,
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("You are HelpingAI a emotional AI always answer my question in HelpingAI style", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        # step=1 keeps the integer-valued options integral in the UI.
        gr.Slider(128, 4096, 1024, step=1, label="Max new tokens"),
        gr.Slider(1, 80, 40, step=1, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR),
).queue().launch()