sagar007 committed
Commit d0ce6f0 (verified)
1 Parent(s): 42dc3ae

Update app.py

Files changed (1):
  1. app.py +30 -56
app.py CHANGED
@@ -1,45 +1,26 @@
 import gradio as gr
-import spaces
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-# HTML template for custom UI
-HTML_TEMPLATE = """
-<style>
-.llama-image {
-    display: flex;
-    justify-content: center;
-    margin-bottom: 20px;
-}
-.llama-image img {
-    max-width: 300px;
-    border-radius: 10px;
-    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
-}
-.llama-description {
-    text-align: center;
-    font-weight: bold;
-    margin-top: 10px;
-}
-</style>
-<div class="llama-image">
-    <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
-    <div class="llama-description">Llama-3.1-Storm-8B Model</div>
-</div>
-<h1>Llama-3.1-Storm-8B Text Generation</h1>
-<p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
-"""
+from transformers import AutoTokenizer, pipeline
 
 # Load the model and tokenizer
 model_name = "akjindal53244/Llama-3.1-Storm-8B"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
+pipe = pipeline(
+    "text-generation",
+    model=model_name,
     torch_dtype=torch.bfloat16,
     device_map="auto"
 )
 
-@spaces.GPU(duration=120)
+# HTML content
+HTML_CONTENT = """
+<h1>Llama-3.1-Storm-8B Text Generation</h1>
+<p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
+<div class="llama-image">
+    <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama" style="width:200px; border-radius:10px;">
+</div>
+"""
+
 def generate_text(prompt, max_length, temperature):
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -47,10 +28,8 @@ def generate_text(prompt, max_length, temperature):
     ]
     formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
-    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
-
-    outputs = model.generate(
-        **inputs,
+    outputs = pipe(
+        formatted_prompt,
         max_new_tokens=max_length,
         do_sample=True,
         temperature=temperature,
@@ -58,25 +37,20 @@
         top_p=0.95,
     )
 
-    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+    return outputs[0]['generated_text'][len(formatted_prompt):]
 
-# Create Gradio interface
-iface = gr.Interface(
-    fn=generate_text,
-    inputs=[
-        gr.Textbox(lines=5, label="Prompt"),
-        gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length"),
-        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
-    ],
-    outputs=gr.Textbox(lines=10, label="Generated Text"),
-    title="Llama-3.1-Storm-8B Text Generation",
-    description="Enter a prompt to generate text using the Llama-3.1-Storm-8B model.",
-    article=None,
-    css=".gradio-container {max-width: 800px; margin: auto;}",
-)
+with gr.Blocks() as demo:
+    gr.HTML(HTML_CONTENT)
+    with gr.Row():
+        with gr.Column(scale=2):
+            prompt = gr.Textbox(label="Prompt", lines=5)
+            max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length")
+            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
+            submit_button = gr.Button("Generate")
+        with gr.Column(scale=2):
+            output = gr.Textbox(label="Generated Text", lines=10)
+
+    submit_button.click(generate_text, inputs=[prompt, max_length, temperature], outputs=[output])
 
-iface.launch(
-    additional_inputs=[
-        gr.HTML(HTML_TEMPLATE)
-    ]
-)
+if __name__ == "__main__":
+    demo.launch()
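
Note for reviewers: the committed code strips the prompt by slicing outputs[0]['generated_text'] at len(formatted_prompt), which relies on the pipeline returning the full text (its default). Below is a minimal standalone sketch, not part of the commit, of the same generation path using the pipeline's standard return_full_text=False argument instead; the model name and sampling defaults are taken from the diff above, and the example user message is illustrative only.

# Sketch only: equivalent generation letting the pipeline drop the prompt.
import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="akjindal53244/Llama-3.1-Storm-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write a haiku about storms."},  # example prompt, not from the commit
]
# Reuse the pipeline's own tokenizer to build the chat-formatted prompt.
formatted_prompt = pipe.tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)

outputs = pipe(
    formatted_prompt,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    return_full_text=False,  # return only the completion, not the prompt
)
print(outputs[0]["generated_text"])

Either approach yields only the completion; return_full_text=False expresses the same intent without the manual character-offset arithmetic.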