import os

import gradio as gr
from huggingface_hub import InferenceClient

from prompts import SYSTEM_PROMPT

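# Responses are generated remotely through the Cerebras provider on the
# Hugging Face Inference API, so the app itself needs no local GPU.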
client = InferenceClient(
    "meta-llama/Llama-3.3-70B-Instruct",
    provider="cerebras",
    token=os.getenv("HF_TOKEN"),
)


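# Assemble the conversation (system prompt + prior turns + new user message)
# and stream the assistant's reply, yielding the accumulated text as it grows.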
def chat_with_llama(message, history, system_prompt):
    messages = []

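    # Include the trimmed system prompt only when one was provided.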
    if system_prompt and system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})

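    # Replay prior turns; Gradio's "messages" history is a list of role/content dicts.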
    for msg in history:
        if isinstance(msg, dict) and msg["role"] in ("user", "assistant"):
            messages.append({"role": msg["role"], "content": msg["content"]})

    messages.append({"role": "user", "content": message})

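    # Stream tokens back as they arrive rather than waiting for the full reply.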
    generation_params = {
        "messages": messages,
        "stream": True,
    }

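    # Accumulate streamed deltas, yielding the running text so the UI updates live.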
response = "" |
|
try: |
|
for chunk in client.chat_completion(**generation_params): |
|
if hasattr(chunk, 'choices') and len(chunk.choices) > 0: |
|
delta = chunk.choices[0].delta |
|
if hasattr(delta, 'content') and delta.content: |
|
response += delta.content |
|
yield response |
|
except Exception as e: |
|
yield f"Error: {str(e)}" |
|
|
|
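# Two-column layout: chat and input on the left, an editable system prompt on the right.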
with gr.Blocks(title="AI Chat Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI Chat Assistant")
    gr.Markdown("Chat with Llama-3.3-70B powered by Cerebras")

    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(
                height=500,
                type="messages",
                show_copy_button=True,
            )

            msg = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                lines=2,
            )

            with gr.Row():
                send_btn = gr.Button("Send", variant="primary", size="lg")
                clear_btn = gr.Button("Clear", size="lg")

        with gr.Column():
            system_prompt = gr.Textbox(
                label="System Prompt",
                value=SYSTEM_PROMPT,
                lines=10,
                show_copy_button=True,
            )

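    # Append the user's turn immediately, then re-yield the full history with the
    # assistant's growing reply so the chat updates as tokens stream in; the
    # empty string in each yield clears the input box.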
    def respond(message, history, system_prompt):
        new_history = history + [{"role": "user", "content": message}]
        for response in chat_with_llama(message, history, system_prompt):
            yield new_history + [{"role": "assistant", "content": response}], ""

    def clear_chat():
        return [], ""

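    # Wire Enter-to-submit and the Send button to the same streaming handler.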
    msg.submit(
        fn=respond,
        inputs=[msg, chatbot, system_prompt],
        outputs=[chatbot, msg],
        show_progress="full",
    )

    send_btn.click(
        fn=respond,
        inputs=[msg, chatbot, system_prompt],
        outputs=[chatbot, msg],
        show_progress="full",
    )

    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, msg],
        show_progress="full",
    )


if __name__ == "__main__":
    demo.launch()