File size: 3,431 Bytes
985eabb
1e235cc
985eabb
1e235cc
cc1b568
8b8d0cf
 
 
1e235cc
 
8b8d0cf
 
 
 
1e235cc
1f7ba92
02a0e92
1f7ba92
 
02a0e92
1f7ba92
 
1e235cc
 
 
 
1f7ba92
 
 
 
 
02a0e92
 
1e235cc
 
d692c8b
1e235cc
 
d692c8b
 
 
1e235cc
d692c8b
 
 
1e235cc
d692c8b
 
 
 
1e235cc
 
d692c8b
 
 
 
 
 
1e235cc
d692c8b
 
 
 
1e235cc
d692c8b
 
1e235cc
d692c8b
 
1e235cc
d692c8b
1e235cc
d692c8b
 
 
 
 
1e235cc
d692c8b
 
 
1e235cc
d692c8b
 
 
 
 
 
 
 
 
1e235cc
d692c8b
 
1e235cc
 
1f7ba92
d692c8b
b3f66be
d692c8b
 
 
b3f66be
 
d692c8b
b3f66be
d692c8b
 
d0ce6f0
d692c8b
 
b3f66be
 
d692c8b
 
 
 
 
 
1f7ba92
d692c8b
1e235cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Checkpoint identifier shared by the tokenizer and the model below.
model_name = "akjindal53244/Llama-3.1-Storm-8B"

# Tokenizer for the checkpoint; generate_text also uses it for chat formatting.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Causal LM loaded once at import time, with bfloat16 weights and device
# placement delegated to the library through device_map="auto".
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)

@spaces.GPU(duration=120)
def generate_text(prompt: str, max_length: int, temperature: float) -> str:
    """Generate an assistant reply to ``prompt`` with the module-level model.

    The prompt is wrapped in a two-message chat (a fixed system message plus
    the user's text), rendered to text with the tokenizer's chat template,
    tokenized, and sampled with ``top_k=100`` / ``top_p=0.95``.

    Args:
        prompt: User text placed in the ``user`` message.
        max_length: Upper bound on the number of newly generated tokens
            (forwarded as ``max_new_tokens``). Coerced to ``int`` because the
            value comes from a UI slider and may arrive as a float.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded continuation only: the prompt tokens are sliced off the
        front of the output and special tokens are skipped.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # The rendered chat text already carries the special tokens the template
    # wants (Llama-3 style templates emit the BOS token themselves), so the
    # tokenizer must not prepend its own again.
    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=int(max_length),
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
    )

    # Keep only the tokens produced after the prompt.
    prompt_token_count = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_token_count:], skip_special_tokens=True)

# Custom CSS handed to gr.Blocks(css=...) when the interface is built.
# The .header rules target the <div class="header"> banner rendered via gr.HTML;
# .input-group, .output-group and .generate-btn correspond to the elem_classes
# values given to the Gradio components in this file.
# NOTE(review): .container is not referenced by any elem_classes in this file —
# confirm whether it is still needed.
css = """
body {
    background-color: #1a1a2e;
    color: #e0e0e0;
    font-family: 'Arial', sans-serif;
}
.container {
    max-width: 900px;
    margin: auto;
    padding: 20px;
}
.gradio-container {
    background-color: #16213e;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
    background-color: #0f3460;
    padding: 20px;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 20px;
}
.header h1 {
    color: #e94560;
    font-size: 2.5em;
    margin-bottom: 10px;
}
.header p {
    color: #a0a0a0;
}
.header img {
    max-width: 200px;
    border-radius: 10px;
    margin-top: 15px;
}
.input-group, .output-group {
    background-color: #1a1a2e;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.input-group label, .output-group label {
    color: #e94560;
    font-weight: bold;
}
.generate-btn {
    background-color: #e94560 !important;
    color: white !important;
    border: none !important;
    border-radius: 5px !important;
    padding: 10px 20px !important;
    font-size: 16px !important;
    cursor: pointer !important;
    transition: background-color 0.3s ease !important;
}
.generate-btn:hover {
    background-color: #c81e45 !important;
}
"""

# Build the UI: header banner, input controls, output box, and click wiring.
with gr.Blocks(css=css) as iface:
    # Static banner; its "header" class is styled by the `css` string.
    gr.HTML(
        """
        <div class="header">
            <h1>Llama-3.1-Storm-8B Text Generation</h1>
            <p>Generate text using the powerful Llama-3.1-Storm-8B model. Enter a prompt and let the AI create!</p>
            <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
        </div>
        """
    )

    # Input controls; their values are passed to generate_text in the order
    # listed in the click handler below.
    with gr.Group(elem_classes="input-group"):
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=5,
        )
        max_length = gr.Slider(
            label="Max Length",
            minimum=1,
            maximum=500,
            step=1,
            value=128,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.7,
        )
        generate_btn = gr.Button(value="Generate", elem_classes="generate-btn")

    # Output area that receives the string returned by generate_text.
    with gr.Group(elem_classes="output-group"):
        output = gr.Textbox(label="Generated Text", lines=10)

    generate_btn.click(
        fn=generate_text,
        inputs=[prompt, max_length, temperature],
        outputs=output,
    )

# Start serving the interface.
iface.launch()