import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from googlesearch import search
import requests
from bs4 import BeautifulSoup
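# Note: the `advanced=True` search flag used below matches the `googlesearch-python`
# package's API (results expose .url, .title, .description); which search package the
# Space actually installs is an assumption here.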
# Load the model and tokenizer
model_name = "akjindal53244/Llama-3.1-Storm-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
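# Note: device_map="auto" relies on the `accelerate` package being installed.
# As a rough estimate (not measured on this Space), bfloat16 weights for an
# 8B-parameter model take about 16 GB of accelerator memory.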

def fetch_web_content(url):
    """Fetch a page and return the concatenated text of its <p> tags."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        return ' '.join(p.get_text() for p in soup.find_all('p'))
    except Exception as e:
        print(f"Error fetching {url}: {str(e)}")
        return "Could not fetch content from this URL"
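# Illustrative call (the URL is only an example, not one the app uses):
#   fetch_web_content("https://example.com")  # -> paragraph text of the page, or the fallback string on error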

def web_search(query, num_results=3):
    """Run a Google search and return a list of {title, url, content} dicts."""
    try:
        results = []
        # advanced=True yields result objects with .url, .title and .description
        for j in search(query, num_results=num_results, advanced=True):
            content = fetch_web_content(j.url)
            results.append({
                "title": j.title,
                "url": j.url,
                "content": content[:1000]  # Limit content length
            })
        return results
    except Exception as e:
        print(f"Search error: {str(e)}")
        return []
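# Shape of the value web_search returns (the entries below are purely illustrative):
#   [{"title": "Some page", "url": "https://example.com", "content": "first 1000 chars of page text"}]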

@spaces.GPU(duration=120)
def generate_text(prompt, max_length, temperature, use_web):
    # Optionally prepend scraped web context to the user's query
    if use_web:
        search_results = web_search(prompt)
        context = "\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in search_results])
        prompt = f"Web Context:\n{context}\n\nUser Query: {prompt}"

    messages = [
        {"role": "system", "content": "You are a helpful assistant with web search capabilities."},
        {"role": "user", "content": prompt}
    ]
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
    )
    # Return only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
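# Minimal sketch of a direct (non-UI) call, assuming the model has loaded; the
# query and settings here are illustrative only:
#   reply = generate_text("Explain quantum computing in simple terms", 256, 0.7, False)
#   print(reply)
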
# CSS and UI components
css = """
:root {
    --primary: #e94560;
    --secondary: #1a1a2e;
    --background: #16213e;
    --text: #e0e0e0;
}
body {
    background-color: var(--background);
    color: var(--text);
    font-family: 'Inter', sans-serif;
}
.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
.gradio-container {
    background-color: var(--background);
    border-radius: 15px;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);
}
.header {
    background: linear-gradient(135deg, #0f3460 0%, #1a1a2e 100%);
    padding: 2rem;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 2rem;
}
.header h1 {
    color: var(--primary);
    font-size: 2.8rem;
    margin-bottom: 1rem;
    font-weight: 700;
}
.input-group, .output-group {
    background-color: var(--secondary);
    padding: 2rem;
    border-radius: 12px;
    margin-bottom: 2rem;
    border: 1px solid #2d2d4d;
}
.generate-btn {
    background: linear-gradient(135deg, var(--primary) 0%, #c81e45 100%) !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 12px 28px !important;
}
.example-prompts ul {
    display: grid;  /* required for grid-template-columns below to take effect */
    grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
    gap: 1rem;
}
"""
example_prompts = [
    "Explain quantum computing in simple terms",
    "Latest developments in AI research",
    "Compare React and Vue.js frameworks",
    "Recent advancements in cancer treatment"
]
with gr.Blocks(css=css, theme=gr.themes.Default()) as iface:
    gr.HTML("""
    <div class="header">
        <h1>Llama-3.1-Storm-8B AI Assistant</h1>
        <p>Enhanced with real-time web search capabilities</p>
    </div>
    """)

    with gr.Tabs():
        with gr.TabItem("Chat Assistant"):
            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Group(elem_classes="example-prompts"):
                        gr.Markdown("## Example Queries")
                        example_btns = [gr.Button(prompt) for prompt in example_prompts]

                    with gr.Group(elem_classes="input-group"):
                        prompt = gr.Textbox(label="Your Query", placeholder="Enter your question...", lines=5)
                        with gr.Row():
                            web_search_toggle = gr.Checkbox(label="Enable Web Search", value=False)
                            num_results = gr.Slider(1, 5, value=3, step=1, label="Search Results")
                        with gr.Row():
                            max_length = gr.Slider(32, 1024, value=256, step=32, label="Response Length")
                            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Creativity")
                        generate_btn = gr.Button("Generate Response", elem_classes="generate-btn")

                with gr.Column(scale=2):
                    with gr.Group(elem_classes="output-group"):
                        output = gr.Textbox(label="Generated Response", lines=12)
                        with gr.Row():
                            copy_btn = gr.Button("Copy")
                            clear_btn = gr.Button("Clear")

        with gr.TabItem("Web Results"):
            web_results = gr.JSON(label="Search Results Preview")

    # Event handlers
    generate_btn.click(
        generate_text,
        inputs=[prompt, max_length, temperature, web_search_toggle],
        outputs=output
    ).then(
        # Populate the "Web Results" tab; only search when the toggle is on,
        # and honour the "Search Results" slider for this preview.
        lambda q, use_web, n: web_search(q, num_results=int(n)) if (q and use_web) else [],
        inputs=[prompt, web_search_toggle, num_results],
        outputs=web_results
    )

    for btn in example_btns:
        btn.click(lambda x: x, inputs=[btn], outputs=[prompt])
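    # Note: passing `btn` itself as an input means Gradio supplies that button's label
    # as the lambda's argument at click time, so each handler copies the right example
    # prompt even though every button shares the same lambda.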

    # Copy the response to the clipboard client-side (Gradio 4.x names this
    # argument `js`; older 3.x builds called it `_js`).
    copy_btn.click(
        None,
        inputs=[output],
        js="(text) => { navigator.clipboard.writeText(text); return [] }"
    )
    clear_btn.click(lambda: "", outputs=[output])

iface.launch()