Spaces:
Runtime error
from __future__ import annotations as _annotations | |
import json | |
import os | |
from contextlib import asynccontextmanager | |
from dataclasses import dataclass | |
from typing import AsyncGenerator | |
import asyncpg | |
import gradio as gr | |
import numpy as np | |
import pydantic_core | |
from gradio_webrtc import ( | |
AdditionalOutputs, | |
ReplyOnPause, | |
WebRTC, | |
audio_to_bytes, | |
get_twilio_turn_credentials, | |
) | |
from groq import Groq | |
from openai import AsyncOpenAI | |
from pydantic import BaseModel | |
from pydantic_ai import RunContext | |
from pydantic_ai.agent import Agent | |
from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn | |
# Documentation index loaded once at import time. Entries are read by their
# "id" and "path" keys when rendering which docs a retrieval cited.
# A context manager closes the handle promptly; JSON is UTF-8 by spec, so do
# not depend on the platform's default encoding.
with open("gradio_docs.json", encoding="utf-8") as _docs_file:
    DOCS = json.load(_docs_file)
# Groq client, used for speech-to-text transcription of the user's audio.
groq_client = Groq()
# Shared async OpenAI client, handed to the agent's tool through ``Deps``.
openai = AsyncOpenAI()
@dataclass
class Deps:
    """Dependencies made available to agent tools via ``RunContext.deps``.

    Must be a dataclass: callers construct it with keyword arguments,
    e.g. ``Deps(openai=client, pool=pool)``.
    """

    # Async OpenAI client; the retrieval tool uses it to embed queries.
    openai: AsyncOpenAI
    # asyncpg connection pool used by the retrieval tool to query doc sections.
    pool: asyncpg.Pool
# System prompt for the docs agent. It tells the model it has a retrieval
# tool and limits the conversation to Gradio-related questions. The text is
# sent verbatim to the model, so spelling matters.
SYSTEM_PROMPT = (
    "You are an assistant designed to help users answer questions about Gradio. "
    "You have a retrieval tool that can provide relevant documentation sections based on the user query. "
    "Be courteous and helpful to the user but feel free to refuse answering questions that are not about Gradio. "
)
# The Pydantic AI agent answering questions, backed by GPT-4o. Its tools
# receive a ``Deps`` instance through ``RunContext``, and ``SYSTEM_PROMPT``
# sets its persona and scope.
agent = Agent(
    "openai:gpt-4o",
    system_prompt=SYSTEM_PROMPT,
    deps_type=Deps,
)
class RetrievalResult(BaseModel):
    """Result of the ``retrieve`` tool, exchanged as a JSON string.

    ``retrieve`` serializes it with ``model_dump_json()``. The UI code parses
    it back with ``model_validate_json()`` to map ``ids`` to doc paths.
    Field order determines the key order of the serialized JSON.
    """

    # Matched sections, each rendered as "# <title>" followed by its content,
    # joined by blank lines.
    content: str
    # ``id`` values of the matched ``doc_sections`` rows, in query result order.
    ids: list[int]
@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[asyncpg.Pool, None]:
    """Open an asyncpg pool to the ``gradio_ai_rag`` database.

    Use as ``async with database_connect() as pool:``. The pool is closed when
    the ``async with`` block exits, including on error.

    Args:
        create_db: If True, first connect to the server named by
            ``DATABASE_URL`` and create ``gradio_ai_rag`` if it is missing.

    Yields:
        An ``asyncpg`` connection pool for ``gradio_ai_rag``.

    Raises:
        RuntimeError: If the ``DATABASE_URL`` environment variable is unset
            or empty.
    """
    server_dsn = os.getenv("DATABASE_URL")
    if not server_dsn:
        # Fail fast with a clear message. Otherwise asyncpg would try to
        # connect to the literal DSN "None/gradio_ai_rag".
        raise RuntimeError("DATABASE_URL environment variable is not set")
    # The database name is appended below as a path segment, so DATABASE_URL
    # is presumably a server-level DSN. Drop a trailing slash so the join does
    # not produce "//".
    server_dsn = server_dsn.rstrip("/")
    database = "gradio_ai_rag"

    if create_db:
        conn = await asyncpg.connect(server_dsn)
        try:
            db_exists = await conn.fetchval(
                "SELECT 1 FROM pg_database WHERE datname = $1", database
            )
            if not db_exists:
                # CREATE DATABASE cannot take bind parameters; ``database`` is
                # a fixed identifier defined above, not user input.
                await conn.execute(f"CREATE DATABASE {database}")
        finally:
            await conn.close()

    pool = await asyncpg.create_pool(f"{server_dsn}/{database}")
    try:
        yield pool
    finally:
        await pool.close()
@agent.tool
async def retrieve(context: RunContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.
    """
    # The docstring above is deliberately left terse: Pydantic AI derives the
    # tool description shown to the model from it. The ``@agent.tool``
    # decorator is what exposes this function to the agent at all.
    print(f"create embedding for {search_query}")
    response = await context.deps.openai.embeddings.create(
        input=search_query,
        model="text-embedding-3-small",
    )
    # One input string was sent, so exactly one embedding must come back.
    # Use an explicit check rather than ``assert`` so it survives ``python -O``.
    if len(response.data) != 1:
        raise ValueError(
            f"Expected 1 embedding, got {len(response.data)}, doc query: {search_query!r}"
        )
    vector = response.data[0].embedding
    # Send the vector as its JSON text form ("[x, y, ...]"). NOTE(review):
    # presumably the server casts this to the ``embedding`` column's vector
    # type for the ``<->`` distance operator; confirm against the schema.
    vector_json = pydantic_core.to_json(vector).decode()
    # The 8 sections closest to the query by ``<->`` distance.
    rows = await context.deps.pool.fetch(
        "SELECT id, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8",
        vector_json,
    )
    content = "\n\n".join(f'# {row["title"]}\n{row["content"]}\n' for row in rows)
    ids = [row["id"] for row in rows]
    return RetrievalResult(content=content, ids=ids).model_dump_json()
def _relevant_doc_paths(tool_content: str) -> list[str]:
    """Return the sorted, de-duplicated ``DOCS`` paths cited by a tool result.

    Args:
        tool_content: JSON produced by ``RetrievalResult.model_dump_json()``.
    """
    # Parse once, then test membership against a set: O(len(DOCS)) overall.
    cited_ids = set(RetrievalResult.model_validate_json(tool_content).ids)
    # Sorting makes the displayed order deterministic across runs.
    return sorted({doc["path"] for doc in DOCS if doc["id"] in cited_ids})


async def stream_from_agent(
    audio: tuple[int, np.ndarray], chatbot: list[dict], past_messages: list
):
    """Transcribe a spoken question, run the docs agent, and stream the answer.

    Every ``yield`` is an ``AdditionalOutputs(chatbot_value, history_value)``.
    The UI routes these to the chatbot and the message-history state;
    ``gr.skip()`` leaves that output unchanged.

    Args:
        audio: Audio chunk from ``ReplyOnPause``, presumably
            ``(sample_rate, samples)``. It is converted with ``audio_to_bytes``.
        chatbot: Chat history in Gradio "messages" format; mutated in place.
        past_messages: Pydantic AI message history; mutated in place.
    """
    # Speech-to-text via Groq Whisper. NOTE(review): this client call is
    # synchronous, so it blocks the event loop while the request runs.
    question = groq_client.audio.transcriptions.create(
        file=("audio-file.mp3", audio_to_bytes(audio)),
        model="whisper-large-v3-turbo",
        response_format="verbose_json",
    ).text
    print("text", question)
    chatbot.append({"role": "user", "content": question})
    yield AdditionalOutputs(chatbot, gr.skip())

    async with database_connect(False) as pool:
        deps = Deps(openai=openai, pool=pool)
        async with agent.run_stream(
            question, deps=deps, message_history=past_messages
        ) as result:
            for message in result.new_messages():
                past_messages.append(message)
                if isinstance(message, ModelStructuredResponse):
                    # One placeholder "tool call" bubble per call. It is keyed
                    # by tool_id and filled in when the matching ToolReturn
                    # arrives.
                    for call in message.calls:
                        chatbot.append(
                            {
                                "role": "assistant",
                                "content": "",
                                "metadata": {
                                    "title": "🔍 Retrieving relevant docs",
                                    "id": call.tool_id,
                                },
                            }
                        )
                if isinstance(message, ToolReturn):
                    for gr_message in chatbot:
                        if (
                            gr_message.get("metadata", {}).get("id", "")
                            == message.tool_id
                        ):
                            # Join outside the f-string. A backslash inside an
                            # f-string expression is a SyntaxError before
                            # Python 3.12.
                            paths = "\n".join(_relevant_doc_paths(message.content))
                            gr_message["content"] = f"Relevant Context:\n {paths}"
                yield AdditionalOutputs(chatbot, gr.skip())
            # Stream the final text answer into a fresh assistant bubble.
            chatbot.append({"role": "assistant", "content": ""})
            async for partial_text in result.stream_text():
                chatbot[-1]["content"] = partial_text
                yield AdditionalOutputs(chatbot, gr.skip())
            data = await result.get_data()
            past_messages.append(ModelTextResponse(content=data))
            yield AdditionalOutputs(gr.skip(), past_messages)
with gr.Blocks() as demo:
    # HTML shown in the empty chatbot. The logo is served through Gradio's
    # file route, which is permitted by ``allowed_paths`` in ``launch`` below.
    placeholder = """
<div style="display: flex; justify-content: center; align-items: center; gap: 1rem; padding: 1rem; width: 100%">
<img src="/gradio_api/file=logo.svg" style="max-width: 200px; height: auto">
<div>
<h1 style="margin: 0 0 1rem 0">Chat with Gradio Docs 🗣️</h1>
<h3 style="margin: 0 0 0.5rem 0">
Simple RAG agent over Gradio docs built with Pydantic AI.
</h3>
<h3 style="margin: 0">
Ask any question about Gradio with your natural voice and get an answer!
</h3>
</div>
</div>
"""
    # Per-session Pydantic AI message history, passed back into each turn.
    past_messages = gr.State([])
    chatbot = gr.Chatbot(
        label="Gradio Docs Bot",
        type="messages",
        placeholder=placeholder,
        avatar_images=(None, "logo.svg"),
    )
    # Send-only WebRTC microphone input, using Twilio TURN credentials for
    # connectivity.
    audio = WebRTC(
        label="Talk with the Agent",
        modality="audio",
        rtc_configuration=get_twilio_turn_credentials(),
        mode="send",
    )
    # ReplyOnPause calls stream_from_agent when the speaker pauses.
    audio.stream(
        ReplyOnPause(stream_from_agent),
        inputs=[audio, chatbot, past_messages],
        outputs=[audio],
    )
    # Each AdditionalOutputs(chatbot, history) yielded by the handler is
    # forwarded unchanged to the chatbot and the history state.
    audio.on_additional_outputs(
        lambda c, s: (c, s),
        outputs=[chatbot, past_messages],
        queue=False,
        show_progress="hidden",
    )
if __name__ == "__main__":
    demo.launch(allowed_paths=["logo.svg"])