# Llama-3.1-Storm-8B chat assistant with optional real-time web search (Gradio app).
import gradio as gr
import requests
import spaces
import torch
from bs4 import BeautifulSoup
from googlesearch import search
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the tokenizer and the Storm-8B causal LM once at import time so every
# request reuses the same weights.
model_name = "akjindal53244/Llama-3.1-Storm-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# bfloat16 halves memory vs fp32; device_map="auto" lets accelerate place the
# layers on whatever GPU/CPU mix is available.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
def fetch_web_content(url):
    """Fetch *url* and return the concatenated text of its ``<p>`` elements.

    Best-effort: any network, HTTP, or parsing failure is logged and a
    fallback string is returned, so callers never see an exception.

    Args:
        url: Page to download.

    Returns:
        Paragraph text joined by spaces, or a fixed fallback message.
    """
    try:
        response = requests.get(url, timeout=10)
        # Treat HTTP error pages (404/500/...) as failures instead of
        # scraping their error body as if it were real content.
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        return ' '.join(p.get_text() for p in soup.find_all('p'))
    except Exception as e:
        print(f"Error fetching {url}: {str(e)}")
        return "Could not fetch content from this URL"
def web_search(query, num_results=3):
    """Run a Google search and return up to *num_results* result dicts.

    Each dict carries "title", "url" and "content" (first 1000 characters
    of the page's paragraph text). Returns an empty list on any failure.
    """
    try:
        hits = search(query, num_results=num_results, advanced=True)
        return [
            {
                "title": hit.title,
                "url": hit.url,
                "content": fetch_web_content(hit.url)[:1000],  # cap page text
            }
            for hit in hits
        ]
    except Exception as e:
        print(f"Search error: {str(e)}")
        return []
def generate_text(prompt, max_length, temperature, use_web):
    """Generate a chat response with the Storm-8B model.

    Args:
        prompt: The user's query text.
        max_length: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature (higher = more random).
        use_web: When True, prepend web-search context to the prompt.

    Returns:
        The decoded model response with the echoed prompt stripped.
    """
    if use_web:
        search_results = web_search(prompt)
        # Only prepend a context section when the search actually produced
        # results; an empty "Web Context:" header would just confuse the model.
        if search_results:
            context = "\n".join(
                f"Source: {res['url']}\nContent: {res['content']}"
                for res in search_results
            )
            prompt = f"Web Context:\n{context}\n\nUser Query: {prompt}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant with web search capabilities."},
        {"role": "user", "content": prompt},
    ]
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
    )
    # Decode only the newly generated tokens (everything after the prompt).
    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
# CSS for the Gradio UI. Dark palette driven by CSS custom properties.
css = """
:root {
    --primary: #e94560;
    --secondary: #1a1a2e;
    --background: #16213e;
    --text: #e0e0e0;
}
body {
    background-color: var(--background);
    color: var(--text);
    font-family: 'Inter', sans-serif;
}
.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
.gradio-container {
    background-color: var(--background);
    border-radius: 15px;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);
}
.header {
    background: linear-gradient(135deg, #0f3460 0%, #1a1a2e 100%);
    padding: 2rem;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 2rem;
}
.header h1 {
    color: var(--primary);
    font-size: 2.8rem;
    margin-bottom: 1rem;
    font-weight: 700;
}
.input-group, .output-group {
    background-color: var(--secondary);
    padding: 2rem;
    border-radius: 12px;
    margin-bottom: 2rem;
    border: 1px solid #2d2d4d;
}
.generate-btn {
    background: linear-gradient(135deg, var(--primary) 0%, #c81e45 100%) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 12px 28px !important;
}
.example-prompts ul {
    /* grid-template-columns has no effect without display: grid */
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
    gap: 1rem;
}
"""
# Ready-made queries surfaced as one-click buttons in the UI.
example_prompts = [
    "Explain quantum computing in simple terms",
    "Latest developments in AI research",
    "Compare React and Vue.js frameworks",
    "Recent advancements in cancer treatment",
]
# Build the Gradio UI: chat tab (inputs + generated response) and a tab that
# previews the raw web-search results used as context.
with gr.Blocks(css=css, theme=gr.themes.Default()) as iface:
    gr.HTML("""
    <div class="header">
        <h1>Llama-3.1-Storm-8B AI Assistant</h1>
        <p>Enhanced with real-time web search capabilities</p>
    </div>
    """)
    with gr.Tabs():
        with gr.TabItem("Chat Assistant"):
            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Group(elem_classes="example-prompts"):
                        gr.Markdown("## Example Queries")
                        example_btns = [gr.Button(prompt) for prompt in example_prompts]
                    with gr.Group(elem_classes="input-group"):
                        prompt = gr.Textbox(label="Your Query", placeholder="Enter your question...", lines=5)
                        with gr.Row():
                            web_search_toggle = gr.Checkbox(label="Enable Web Search", value=False)
                            num_results = gr.Slider(1, 5, value=3, step=1, label="Search Results")
                        with gr.Row():
                            max_length = gr.Slider(32, 1024, value=256, step=32, label="Response Length")
                            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Creativity")
                        generate_btn = gr.Button("Generate Response", elem_classes="generate-btn")
                with gr.Column(scale=2):
                    with gr.Group(elem_classes="output-group"):
                        output = gr.Textbox(label="Generated Response", lines=12)
                        with gr.Row():
                            copy_btn = gr.Button("Copy")
                            clear_btn = gr.Button("Clear")
        with gr.TabItem("Web Results"):
            web_results = gr.JSON(label="Search Results Preview")

    # Event handlers
    generate_btn.click(
        generate_text,
        inputs=[prompt, max_length, temperature, web_search_toggle],
        outputs=output,
    ).then(
        # Preview search results only when web search is enabled, and honour
        # the "Search Results" slider (it was previously never wired up).
        lambda q, n, use: web_search(q, int(n)) if (q and use) else [],
        inputs=[prompt, num_results, web_search_toggle],
        outputs=web_results,
    )

    # Clicking an example button copies its label (the button's value) into
    # the prompt box. `btn` is passed as an input, so no late-binding issue.
    for btn in example_btns:
        btn.click(lambda x: x, inputs=[btn], outputs=[prompt])

    # NOTE(review): `_js` is the Gradio 3.x keyword; Gradio 4+ renamed it to
    # `js` — confirm against the installed gradio version.
    copy_btn.click(
        None,
        inputs=[output],
        _js="(text) => { navigator.clipboard.writeText(text); return [] }"
    )
    clear_btn.click(lambda: "", outputs=[output])

iface.launch()