sagar007 committed (verified)
Commit 5195372 · 1 Parent(s): 985eabb

Update app.py

Files changed (1): app.py +6 -1
app.py CHANGED
@@ -1,6 +1,11 @@
+import subprocess
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, LlamaForCausalLM
+import spaces
+
+# Install flash-attn with specific environment variable
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 # Initialize model and tokenizer
 model_id = 'akjindal53244/Llama-3.1-Storm-8B'
@@ -29,7 +34,7 @@ def generate_response(message, history):
     messages.append({"role": "user", "content": message})
 
     prompt = format_prompt(messages)
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
     generated_ids = model.generate(input_ids, max_new_tokens=256, temperature=0.7, do_sample=True, eos_token_id=tokenizer.eos_token_id)
     response = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
     return response.strip()
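
For context on the two changes in this commit: setting FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes the flash-attn pip install skip compiling CUDA kernels (the usual workaround when no GPU is visible while the Space is building), and import spaces pulls in the Hugging Face ZeroGPU helper package. The switch from .to("cuda") to .to(model.device) keeps the input tensors on whatever device the model weights actually occupy. Below is a minimal, self-contained sketch of that device-agnostic pattern; the loading arguments (device_map="auto", bfloat16) are assumptions for illustration, since the model-loading code is not part of these hunks.

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

model_id = 'akjindal53244/Llama-3.1-Storm-8B'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # assumed dtype; not shown in the diff
    device_map="auto",           # assumed; lets accelerate decide weight placement
)

prompt = "Hello"
# .to(model.device) follows the weights wherever they were placed (CPU, a single
# GPU, or a ZeroGPU-attached device), whereas a hard-coded .to("cuda") fails
# whenever no CUDA device is available at call time.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
generated_ids = model.generate(input_ids, max_new_tokens=256, temperature=0.7,
                               do_sample=True, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))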