sagar007 committed on
Commit
42dc3ae
·
verified ·
1 Parent(s): 8b8d0cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -102
app.py CHANGED
@@ -1,111 +1,44 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
4
- from transformers import AutoTokenizer, pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Load the model and tokenizer
7
  model_name = "akjindal53244/Llama-3.1-Storm-8B"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- pipe = pipeline(
10
- "text-generation",
11
- model=model_name,
12
  torch_dtype=torch.bfloat16,
13
  device_map="auto"
14
  )
15
 
16
- # HTML template
17
- HTML_TEMPLATE = """
18
- <style>
19
- body {
20
- background: linear-gradient(135deg, #f5f7fa, #c3cfe2);
21
- font-family: Arial, sans-serif;
22
- }
23
- #app-header {
24
- text-align: center;
25
- background: rgba(255, 255, 255, 0.8);
26
- padding: 20px;
27
- border-radius: 10px;
28
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
29
- position: relative;
30
- max-width: 800px;
31
- margin: 20px auto;
32
- }
33
- #app-header h1 {
34
- color: #4A90E2;
35
- font-size: 2em;
36
- margin-bottom: 10px;
37
- }
38
- .llama-image {
39
- position: relative;
40
- transition: transform 0.3s;
41
- display: inline-block;
42
- margin-top: 20px;
43
- }
44
- .llama-image:hover {
45
- transform: scale(1.05);
46
- }
47
- .llama-image img {
48
- width: 200px;
49
- border-radius: 10px;
50
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
51
- }
52
- .llama-description {
53
- position: absolute;
54
- bottom: -30px;
55
- left: 50%;
56
- transform: translateX(-50%);
57
- background-color: #4A90E2;
58
- color: white;
59
- padding: 5px 10px;
60
- border-radius: 5px;
61
- opacity: 0;
62
- transition: opacity 0.3s;
63
- white-space: nowrap;
64
- }
65
- .llama-image:hover .llama-description {
66
- opacity: 1;
67
- }
68
- .artifact {
69
- position: absolute;
70
- background: rgba(74, 144, 226, 0.1);
71
- border-radius: 50%;
72
- }
73
- .artifact.large {
74
- width: 300px;
75
- height: 300px;
76
- top: -50px;
77
- left: -150px;
78
- }
79
- .artifact.medium {
80
- width: 200px;
81
- height: 200px;
82
- bottom: -50px;
83
- right: -100px;
84
- }
85
- .artifact.small {
86
- width: 100px;
87
- height: 100px;
88
- top: 50%;
89
- left: 50%;
90
- transform: translate(-50%, -50%);
91
- }
92
- </style>
93
-
94
- <div id="app-header">
95
- <div class="artifact large"></div>
96
- <div class="artifact medium"></div>
97
- <div class="artifact small"></div>
98
-
99
- <h1>Llama-3.1-Storm-8B Text Generation</h1>
100
- <p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
101
-
102
- <div class="llama-image">
103
- <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
104
- <div class="llama-description">Llama-3.1-Storm-8B Model</div>
105
- </div>
106
- </div>
107
- """
108
-
109
  @spaces.GPU(duration=120)
110
  def generate_text(prompt, max_length, temperature):
111
  messages = [
@@ -114,8 +47,10 @@ def generate_text(prompt, max_length, temperature):
114
  ]
115
  formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
116
 
117
- outputs = pipe(
118
- formatted_prompt,
 
 
119
  max_new_tokens=max_length,
120
  do_sample=True,
121
  temperature=temperature,
@@ -123,8 +58,9 @@ def generate_text(prompt, max_length, temperature):
123
  top_p=0.95,
124
  )
125
 
126
- return outputs[0]['generated_text'][len(formatted_prompt):]
127
 
 
128
  iface = gr.Interface(
129
  fn=generate_text,
130
  inputs=[
@@ -135,7 +71,12 @@ iface = gr.Interface(
135
  outputs=gr.Textbox(lines=10, label="Generated Text"),
136
  title="Llama-3.1-Storm-8B Text Generation",
137
  description="Enter a prompt to generate text using the Llama-3.1-Storm-8B model.",
138
- article=HTML_TEMPLATE
 
139
  )
140
 
141
- iface.launch()
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+
6
+ # HTML template for custom UI
7
+ HTML_TEMPLATE = """
8
+ <style>
9
+ .llama-image {
10
+ display: flex;
11
+ justify-content: center;
12
+ margin-bottom: 20px;
13
+ }
14
+ .llama-image img {
15
+ max-width: 300px;
16
+ border-radius: 10px;
17
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
18
+ }
19
+ .llama-description {
20
+ text-align: center;
21
+ font-weight: bold;
22
+ margin-top: 10px;
23
+ }
24
+ </style>
25
+ <div class="llama-image">
26
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
27
+ <div class="llama-description">Llama-3.1-Storm-8B Model</div>
28
+ </div>
29
+ <h1>Llama-3.1-Storm-8B Text Generation</h1>
30
+ <p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
31
+ """
32
 
33
  # Load the model and tokenizer
34
  model_name = "akjindal53244/Llama-3.1-Storm-8B"
35
  tokenizer = AutoTokenizer.from_pretrained(model_name)
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ model_name,
 
38
  torch_dtype=torch.bfloat16,
39
  device_map="auto"
40
  )
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @spaces.GPU(duration=120)
43
  def generate_text(prompt, max_length, temperature):
44
  messages = [
 
47
  ]
48
  formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
49
 
50
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
51
+
52
+ outputs = model.generate(
53
+ **inputs,
54
  max_new_tokens=max_length,
55
  do_sample=True,
56
  temperature=temperature,
 
58
  top_p=0.95,
59
  )
60
 
61
+ return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
62
 
63
+ # Create Gradio interface
64
  iface = gr.Interface(
65
  fn=generate_text,
66
  inputs=[
 
71
  outputs=gr.Textbox(lines=10, label="Generated Text"),
72
  title="Llama-3.1-Storm-8B Text Generation",
73
  description="Enter a prompt to generate text using the Llama-3.1-Storm-8B model.",
74
+ article=None,
75
+ css=".gradio-container {max-width: 800px; margin: auto;}",
76
  )
77
 
78
+ iface.launch(
79
+ additional_inputs=[
80
+ gr.HTML(HTML_TEMPLATE)
81
+ ]
82
+ )