import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from googlesearch import search
import requests
from bs4 import BeautifulSoup

# Load the model and tokenizer
model_name = "akjindal53244/Llama-3.1-Storm-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
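# Rough sizing note: 8B parameters in bfloat16 come to about 16 GB of weights;
# device_map="auto" (which requires the `accelerate` package) places them on
# whatever GPU memory is available.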

def fetch_web_content(url):
    """Fetch a page and return the concatenated text of its <p> tags."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # don't parse error pages as if they were content
        soup = BeautifulSoup(response.text, 'html.parser')
        return ' '.join(p.get_text() for p in soup.find_all('p'))
    except Exception as e:
        print(f"Error fetching {url}: {e}")
        return "Could not fetch content from this URL"

def web_search(query, num_results=3):
    """Google the query and return title/url/content dicts for the top hits."""
    try:
        results = []
        # advanced=True makes `search` yield result objects with .title/.url/.description
        for result in search(query, num_results=num_results, advanced=True):
            content = fetch_web_content(result.url)
            results.append({
                "title": result.title,
                "url": result.url,
                "content": content[:1000]  # cap the context taken from each source
            })
        return results
    except Exception as e:
        print(f"Search error: {e}")
        return []
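# Expected return shape (values abridged):
# [{"title": "...", "url": "https://...", "content": "first 1,000 characters of page text"}]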

@spaces.GPU(duration=120)
def generate_text(prompt, max_length, temperature, use_web, num_results=3):
    if use_web:
        # Prepend retrieved web content so the model can ground its answer
        search_results = web_search(prompt, num_results=int(num_results))
        context = "\n".join(f"Source: {res['url']}\nContent: {res['content']}" for res in search_results)
        prompt = f"Web Context:\n{context}\n\nUser Query: {prompt}"

    messages = [
        {"role": "system", "content": "You are a helpful assistant with web search capabilities."},
        {"role": "user", "content": prompt}
    ]
    
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
    )

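    # Slice off the prompt tokens so only the newly generated text is decoded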
    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# CSS and UI components
css = """
:root {
    --primary: #e94560;
    --secondary: #1a1a2e;
    --background: #16213e;
    --text: #e0e0e0;
}

body {
    background-color: var(--background);
    color: var(--text);
    font-family: 'Inter', sans-serif;
}

.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}

.gradio-container {
    background-color: var(--background);
    border-radius: 15px;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);
}

.header {
    background: linear-gradient(135deg, #0f3460 0%, #1a1a2e 100%);
    padding: 2rem;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 2rem;
}

.header h1 {
    color: var(--primary);
    font-size: 2.8rem;
    margin-bottom: 1rem;
    font-weight: 700;
}

.input-group, .output-group {
    background-color: var(--secondary);
    padding: 2rem;
    border-radius: 12px;
    margin-bottom: 2rem;
    border: 1px solid #2d2d4d;
}

.generate-btn {
    background: linear-gradient(135deg, var(--primary) 0%, #c81e45 100%) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 12px 28px !important;
}

.example-prompts ul {
    display: grid;  /* grid-template-columns has no effect without this */
    grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
    gap: 1rem;
}
"""

example_prompts = [
    "Explain quantum computing in simple terms",
    "Latest developments in AI research",
    "Compare React and Vue.js frameworks",
    "Recent advancements in cancer treatment"
]

with gr.Blocks(css=css, theme=gr.themes.Default()) as iface:
    gr.HTML("""
        <div class="header">
            <h1>Llama-3.1-Storm-8B AI Assistant</h1>
            <p>Enhanced with real-time web search capabilities</p>
        </div>
    """)

    with gr.Tabs():
        with gr.TabItem("Chat Assistant"):
            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Group(elem_classes="example-prompts"):
                        gr.Markdown("## Example Queries")
                        example_btns = [gr.Button(prompt) for prompt in example_prompts]

                    with gr.Group(elem_classes="input-group"):
                        prompt = gr.Textbox(label="Your Query", placeholder="Enter your question...", lines=5)
                        
                        with gr.Row():
                            web_search_toggle = gr.Checkbox(label="Enable Web Search", value=False)
                            num_results = gr.Slider(1, 5, value=3, step=1, label="Search Results")
                        
                        with gr.Row():
                            max_length = gr.Slider(32, 1024, value=256, step=32, label="Response Length")
                            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Creativity")
                        
                        generate_btn = gr.Button("Generate Response", elem_classes="generate-btn")

                with gr.Column(scale=2):
                    with gr.Group(elem_classes="output-group"):
                        output = gr.Textbox(label="Generated Response", lines=12)
                        with gr.Row():
                            copy_btn = gr.Button("Copy")
                            clear_btn = gr.Button("Clear")

        with gr.TabItem("Web Results"):
            web_results = gr.JSON(label="Search Results Preview")

    # Event handlers
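    # The click first generates the response, then (.then) refreshes the
    # Web Results tab with the raw search hits when web search is enabled.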
    generate_btn.click(
        generate_text,
        inputs=[prompt, max_length, temperature, web_search_toggle, num_results],
        outputs=output
    ).then(
        lambda q, use_web: web_search(q) if (q and use_web) else [],
        inputs=[prompt, web_search_toggle],
        outputs=web_results
    )
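    # gr.Button components can serve as inputs: each click passes the button's
    # own label, which the identity lambda copies into the prompt box.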

    for btn in example_btns:
        btn.click(lambda x: x, inputs=[btn], outputs=[prompt])

    copy_btn.click(
        None,
        inputs=[output],
        # Gradio 4 renamed the client-side hook from `_js` to `js`
        js="(text) => { navigator.clipboard.writeText(text); return [] }"
    )
    
    clear_btn.click(lambda: "", outputs=[output])

iface.launch()