sagar007 committed
Commit cc1b568 · verified · 1 Parent(s): 1f7ba92

Update app.py

Files changed (1)
  1. app.py +17 -8
app.py CHANGED
@@ -1,16 +1,23 @@
 import gradio as gr
-from transformers import AutoTokenizer, pipeline
+import spaces
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import subprocess
 
+# Install flash-attn
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+# Load the model and tokenizer
 model_name = "akjindal53244/Llama-3.1-Storm-8B"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-pipeline = pipeline(
-    "text-generation",
-    model=model_name,
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
     torch_dtype=torch.bfloat16,
-    device_map="auto",
+    use_flash_attention_2=True,
+    device_map="auto"
 )
 
+@spaces.GPU(duration=120)
 def generate_text(prompt, max_length, temperature):
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -18,8 +25,10 @@ def generate_text(prompt, max_length, temperature):
     ]
     formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
-    outputs = pipeline(
-        formatted_prompt,
+    inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cuda")
+
+    outputs = model.generate(
+        **inputs,
         max_new_tokens=max_length,
         do_sample=True,
         temperature=temperature,
@@ -27,7 +36,7 @@ def generate_text(prompt, max_length, temperature):
         top_p=0.95,
     )
 
-    return outputs[0]["generated_text"]
+    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
 
 iface = gr.Interface(
     fn=generate_text,
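
Note: model.generate returns the full token sequence (prompt tokens followed by the completion), so the new return statement slices off the prompt length before decoding. A minimal sketch of that pattern, assuming the model and tokenizer are already loaded as in app.py above (the example prompt is illustrative only):

# Sketch: decode only the newly generated tokens, not the echoed prompt.
# Assumes `model` and `tokenizer` exist as loaded in app.py.
inputs = tokenizer("Hello, how are", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)   # full sequence: prompt + new tokens
prompt_len = inputs["input_ids"].shape[1]                # number of prompt tokens
completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
print(completion)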