KaizeShi committed
Commit 262d1c8 · 1 Parent(s): 041e285

Add application file

Files changed (1)
  1. app.py +89 -130
app.py CHANGED
@@ -1,148 +1,107 @@
- import spaces
- import torch
- from peft import PeftModel
- import transformers
- import gradio as gr
  import os

-
- assert (
-     "LlamaTokenizer" in transformers._import_structure["models.llama"]
- ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
  from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
- access_token = os.environ.get('HF_TOKEN')

- tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=access_token)

- BASE_MODEL = "meta-llama/Llama-2-7b-hf"
  LORA_WEIGHTS = "DSMI/LLaMA-E"

- if torch.cuda.is_available():
-     device = "cuda"
- else:
-     device = "cpu"

- try:
-     if torch.backends.mps.is_available():
-         device = "mps"
- except:
-     pass

- print("Device: " + str(device))

- if device == "cuda":
-     model = LlamaForCausalLM.from_pretrained(
-         BASE_MODEL,
          load_in_8bit=False,
          torch_dtype=torch.float16,
          device_map="auto",
      )
-     model = PeftModel.from_pretrained(
-         model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
-     )
- elif device == "mps":
-     model = LlamaForCausalLM.from_pretrained(
-         BASE_MODEL,
-         device_map={"": device},
-         torch_dtype=torch.float16,
-     )
-     model = PeftModel.from_pretrained(
-         model,
-         LORA_WEIGHTS,
-         device_map={"": device},
-         torch_dtype=torch.float16,
-     )
- else:
-     model = LlamaForCausalLM.from_pretrained(
-         BASE_MODEL,
-         device_map={"": device},
-         low_cpu_mem_usage=True
-     )
-     model = PeftModel.from_pretrained(
-         model,
-         LORA_WEIGHTS,
-         device_map={"": device},
-     )

- print("Model: " + str(model))
-
- def generate_prompt(instruction, input=None):
-     if input:
-         return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
- ### Instruction:
- {instruction}
- ### Input:
- {input}
- ### Response:"""
-     else:
-         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
- ### Instruction:
- {instruction}
- ### Response:"""
-
- if device != "cpu":
-     model.half()
- model.eval()
- if torch.__version__ >= "2":
-     model = torch.compile(model)
-
- @spaces.GPU()
- def evaluate(
-     instruction,
-     input=None,
-     temperature=0.1,
-     top_p=0.75,
-     top_k=40,
-     num_beams=2,
-     max_new_tokens=64,
-     **kwargs,
- ):
-     prompt = generate_prompt(instruction, input)
-     inputs = tokenizer(prompt, return_tensors="pt")
-     input_ids = inputs["input_ids"].to(device)
-     generation_config = GenerationConfig(
-         temperature=temperature,
-         top_p=top_p,
-         top_k=top_k,
-         num_beams=num_beams,
-         **kwargs,
      )
-     with torch.no_grad():
-         generation_output = model.generate(
-             input_ids=input_ids,
-             generation_config=generation_config,
-             return_dict_in_generate=True,
-             output_scores=True,
-             max_new_tokens=max_new_tokens,
-         )
-     s = generation_output.sequences[0]
-     output = tokenizer.decode(s)
-     return output.split("### Response:")[1].strip()
-

- g = gr.Interface(
-     fn=evaluate,
-     inputs=[
-         gr.components.Textbox(
-             lines=2, label="Instruction", placeholder="Tell me about alpacas."
-         ),
-         gr.components.Textbox(lines=2, label="Input", placeholder="none"),
-         gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
-         gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
-         gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
-         gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
-         gr.components.Slider(
-             minimum=1, maximum=512, step=1, value=128, label="Max tokens"
-         ),
      ],
-     outputs=[
-         gr.Textbox(
-             lines=5,
-             label="Output",
-         )
-     ],
-     title="🦙🛍️ LLaMA-E",
-     description="LLaMA-E is a series of fine-tuned LLaMA model following the E-commerce instructions. It is developed by DSMI (http://dsmi.tech/) @ University of Technology Sydney, and trained on the 120k instruction set. This model is for academic research use only. For more details please contact: Kaize.Shi@uts.edu.au",
- )
- g.queue(concurrency_count=1)
- g.launch()
  import os
+ import json
+ import subprocess
+ from threading import Thread

+ import torch
+ import spaces
+ import gradio as gr
+ from peft import PeftModel  # needed for PeftModel.from_pretrained below; missing in the committed file
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
  from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

+ # Install FlashAttention at startup, skipping the CUDA build inside the Space
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

+ # Configuration, read from the Space's environment variables
+ MODEL_ID = "meta-llama/Llama-2-7b-hf"
+ CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
+ CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
+ COLOR = os.environ.get("COLOR")
+ DESCRIPTION = os.environ.get("DESCRIPTION")
+ access_token = os.environ.get('HF_TOKEN')  # token for the gated Llama-2 checkpoint; undefined in the committed file
  LORA_WEIGHTS = "DSMI/LLaMA-E"


+ @spaces.GPU(duration=120)
+ def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
+     # Format history with a given chat template
+     if CHAT_TEMPLATE == "Auto":
+         stop_tokens = [tokenizer.eos_token_id]
+         instruction = []
+         for user, assistant in history:
+             instruction.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+         instruction.append({"role": "user", "content": message})
+     elif CHAT_TEMPLATE == "ChatML":
+         stop_tokens = ["<|endoftext|>", "<|im_end|>"]
+         instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
+         for user, assistant in history:
+             instruction += '<|im_start|>user\n' + user + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
+         instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
+     elif CHAT_TEMPLATE == "Mistral Instruct":
+         stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
+         instruction = '<s>[INST] ' + system_prompt
+         for user, assistant in history:
+             instruction += user + ' [/INST] ' + assistant + '</s>[INST]'
+         instruction += ' ' + message + ' [/INST]'
+     else:
+         raise Exception("Incorrect chat template, select 'Auto', 'ChatML' or 'Mistral Instruct'")
+     print(instruction)
+
+     # Tokenize the prompt and truncate it to the configured context window
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
+     input_ids, attention_mask = enc.input_ids, enc.attention_mask

+     if input_ids.shape[1] > CONTEXT_LENGTH:
+         input_ids = input_ids[:, -CONTEXT_LENGTH:]
+         attention_mask = attention_mask[:, -CONTEXT_LENGTH:]  # keep the mask aligned with the truncated input

+     # Generate in a background thread and stream tokens back to the UI
+     generate_kwargs = dict(
+         {"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
+         streamer=streamer,
+         do_sample=True,
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         top_p=top_p
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+     outputs = []
+     for new_token in streamer:
+         outputs.append(new_token)
+         if new_token in stop_tokens:
+             break
+         yield "".join(outputs)
+
+
+ # Load model
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=False,
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=access_token)
+ model = LlamaForCausalLM.from_pretrained(
+     MODEL_ID,
      load_in_8bit=False,
      torch_dtype=torch.float16,
      device_map="auto",
  )

+ model = PeftModel.from_pretrained(
+     model, LORA_WEIGHTS, torch_dtype=torch.float16, force_download=True
  )

+ # Create Gradio interface
+ gr.ChatInterface(
+     predict,
+     title="🦙🛍️ LLaMA-E",
+     description=DESCRIPTION,
+     additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
+     additional_inputs=[
+         gr.Textbox("You are HelpingAI a emotional AI always answer my question in HelpingAI style", label="System prompt"),
+         gr.Slider(0, 1, 0.8, label="Temperature"),
+         gr.Slider(128, 4096, 1024, label="Max new tokens"),
+         gr.Slider(1, 80, 40, label="Top K sampling"),
+         gr.Slider(0, 2, 1.1, label="Repetition penalty"),
+         gr.Slider(0, 1, 0.95, label="Top P sampling"),
      ],
+     theme=gr.themes.Soft(primary_hue=COLOR),
+ ).queue().launch()
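
Note: the new app.py reads its configuration from environment variables (HF_TOKEN, CHAT_TEMPLATE, CONTEXT_LENGTH, COLOR, DESCRIPTION), and CONTEXT_LENGTH is parsed with int(), so the app fails at import time if any of them is unset. Below is a minimal sketch of how these variables might be set for a local test run; the variable names come from the committed code, but the values are illustrative assumptions, not defaults shipped with this commit.

# Illustrative only: example environment setup for running this Space locally.
# Names come from app.py; the values are assumptions, not part of the commit.
import os

os.environ.setdefault("HF_TOKEN", "hf_xxx")        # access token for the gated meta-llama/Llama-2-7b-hf checkpoint
os.environ.setdefault("CHAT_TEMPLATE", "ChatML")   # one of "Auto", "ChatML", "Mistral Instruct"
os.environ.setdefault("CONTEXT_LENGTH", "4096")    # parsed with int(), so it must be a numeric string
os.environ.setdefault("COLOR", "blue")             # primary_hue passed to gr.themes.Soft
os.environ.setdefault("DESCRIPTION", "LLaMA-E e-commerce instruction demo")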