""" | |
agent.py | |
This file defines the core logic for a sophisticated AI agent using LangGraph. | |
## MODIFICATION: This version introduces a 'multimodal_router' node. | |
This node intelligently inspects user input to identify, classify (using HEAD requests), | |
and pre-process URLs for images, audio, and video before the main LLM reasoning step. | |
""" | |
# ----------------------------------------------------------
# Section 0: Imports and Configuration
# ----------------------------------------------------------
import json
import os
import pickle
import re
import subprocess
import textwrap
import base64
import functools
from io import BytesIO
from pathlib import Path
import tempfile
import yt_dlp
from pydub import AudioSegment
import speech_recognition as sr
import requests
from cachetools import TTLCache
from PIL import Image
from langchain.schema import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import Tool, tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from dotenv import load_dotenv
load_dotenv()
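# load_dotenv() reads a local .env file if present; ChatGroq expects GROQ_API_KEY and
# TavilySearchResults expects TAVILY_API_KEY to be available in the environment.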
# --- Configuration and Caching (remains the same) ---
JSONL_PATH = Path("metadata.jsonl")
FAISS_CACHE = Path("faiss_index.pkl")
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
RETRIEVER_K, CACHE_TTL = 5, 600
API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
def cached_get(key: str, fetch_fn):
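    """Return the cached value for `key` if present; otherwise call `fetch_fn`, store its result for CACHE_TTL seconds, and return it."""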
    if key in API_CACHE: return API_CACHE[key]
    val = fetch_fn()
    API_CACHE[key] = val
    return val
# ----------------------------------------------------------
# Section 2: Standalone Tool Functions (remains the same)
# ----------------------------------------------------------
@tool
def python_repl(code: str) -> str:
    """Executes a string of Python code and returns the stdout/stderr."""
    # ... (implementation unchanged)
    code = textwrap.dedent(code).strip()
    try:
        result = subprocess.run(["python", "-c", code], capture_output=True, text=True, timeout=10, check=False)
        if result.returncode == 0: return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
        else: return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
    except subprocess.TimeoutExpired: return "Execution timed out (>10s)."
@tool
def process_youtube_video(url: str) -> str:
    """Downloads and processes a YouTube video, extracting audio and converting it to text."""
    # ... (implementation unchanged)
    try:
        print(f"Processing YouTube video: {url}")
        with tempfile.TemporaryDirectory() as temp_dir:
            ydl_opts = {
                'format': 'bestaudio/best', 'outtmpl': f'{temp_dir}/%(title)s.%(ext)s',
                'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav'}],
            }
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                info = ydl.extract_info(url, download=True)
                title = info.get('title', 'Unknown')
            audio_files = list(Path(temp_dir).glob("*.wav"))
            if not audio_files: return "Error: Could not download audio from YouTube video"
            r, transcript_parts = sr.Recognizer(), []
            audio = AudioSegment.from_wav(str(audio_files[0])).set_channels(1).set_frame_rate(16000)
            chunks = [audio[i:i + 30000] for i in range(0, len(audio), 30000)]  # 30-second chunks
            for i, chunk in enumerate(chunks[:10]):  # first 10 chunks = first 5 minutes
                chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
                chunk.export(chunk_file, format="wav")
                try:
                    with sr.AudioFile(str(chunk_file)) as source:
                        text = r.recognize_google(r.record(source))
                    transcript_parts.append(text)
                except (sr.UnknownValueError, sr.RequestError) as e:
                    transcript_parts.append(f"[Speech recognition error or unintelligible audio: {e}]")
            return f"YouTube Video: {title}\n\nTranscript (first 5 minutes):\n{' '.join(transcript_parts)}"
    except Exception as e:
        print(f"Error processing YouTube video: {e}")
        return f"Error processing YouTube video: {e}"
@tool
def process_audio_file(file_url: str) -> str:
    """Downloads and processes an audio file (MP3, WAV, etc.) and converts it to text."""
    # ... (implementation unchanged)
    try:
        print(f"Processing audio file: {file_url}")
        with tempfile.TemporaryDirectory() as temp_dir:
            response = requests.get(file_url, timeout=30)
            response.raise_for_status()
            ext = os.path.splitext(file_url)[1][1:] or 'mp3'
            audio_file = Path(temp_dir) / f"audio.{ext}"
            with open(audio_file, 'wb') as f: f.write(response.content)
            wav_file = Path(temp_dir) / "audio.wav"
            AudioSegment.from_file(str(audio_file)).export(wav_file, format="wav")
            r, transcript_parts = sr.Recognizer(), []
            audio = AudioSegment.from_wav(str(wav_file)).set_channels(1).set_frame_rate(16000)
            chunks = [audio[i:i + 30000] for i in range(0, len(audio), 30000)]  # 30-second chunks
            for i, chunk in enumerate(chunks[:20]):  # first 20 chunks = first 10 minutes
                chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
                chunk.export(chunk_file, format="wav")
                try:
                    with sr.AudioFile(str(chunk_file)) as source:
                        text = r.recognize_google(r.record(source))
                    transcript_parts.append(text)
                except (sr.UnknownValueError, sr.RequestError) as e:
                    transcript_parts.append(f"[Speech recognition error or unintelligible audio: {e}]")
            return f"Audio file transcript:\n{' '.join(transcript_parts)}"
    except Exception as e:
        print(f"Error processing audio file: {e}")
        return f"Error processing audio file: {e}"
def web_search_func(query: str, cache_func) -> str:
    """Performs a web search using Tavily and returns a compilation of results."""
    # ... (implementation unchanged)
    key = f"web:{query}"
    results = cache_func(key, lambda: TavilySearchResults(max_results=5).invoke(query))
    return "\n\n---\n\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in results])
def wiki_search_func(query: str, cache_func) -> str:
    """Searches Wikipedia and returns the top 2 results."""
    # ... (implementation unchanged)
    key = f"wiki:{query}"
    docs = cache_func(key, lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load())
    return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\n\n{d.page_content}" for d in docs])
def arxiv_search_func(query: str, cache_func) -> str:
    """Searches Arxiv for scientific papers and returns the top 2 results."""
    # ... (implementation unchanged)
    key = f"arxiv:{query}"
    docs = cache_func(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
    return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])
# ----------------------------------------------------------
# Section 3: DYNAMIC SYSTEM PROMPT (remains the same)
# ----------------------------------------------------------
SYSTEM_PROMPT_TEMPLATE = (
    """You are an expert-level multimodal research assistant..."""  # Unchanged
)
# ----------------------------------------------------------
# Section 4: Factory Function for Agent Executor
# ----------------------------------------------------------
def create_agent_executor(provider: str = "groq"):
    """
    Factory function to create and compile the LangGraph agent executor.
    """
    print(f"Initializing agent with provider: {provider}")
    # Step 1: Build LLM (remains the same)
    if provider == "groq":
        llm = ChatGroq(model_name="llama-3.1-70b-vision-preview", temperature=0)
    else:
        raise ValueError(f"Provider '{provider}' not currently configured for this router.")
    # Step 2: Build Retriever (remains the same, but will be called inside the router)
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    if FAISS_CACHE.exists():
        with open(FAISS_CACHE, "rb") as f: vector_store = pickle.load(f)
    else:
        # ... logic to build vector_store from JSONL or create empty ...
        docs = []
        if JSONL_PATH.exists():
            docs = [
                Document(
                    page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
                    metadata={"source": rec["task_id"]},
                )
                for rec in (json.loads(line) for line in open(JSONL_PATH, "rt", encoding="utf-8"))
            ]
        if not docs:
            docs = [Document(page_content="Sample document", metadata={"source": "sample"})]
        vector_store = FAISS.from_documents(docs, embeddings)
        with open(FAISS_CACHE, "wb") as f: pickle.dump(vector_store, f)
    retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
    # Step 3: Create the final list of tools (remains the same)
    tools_list = [
        python_repl, process_youtube_video, process_audio_file,
        Tool(name="web_search", func=functools.partial(web_search_func, cache_func=cached_get), description="Performs a web search using Tavily."),
        Tool(name="wiki_search", func=functools.partial(wiki_search_func, cache_func=cached_get), description="Searches Wikipedia."),
        Tool(name="arxiv_search", func=functools.partial(arxiv_search_func, cache_func=cached_get), description="Searches Arxiv for scientific papers."),
        create_retriever_tool(retriever=retriever, name="retrieve_examples", description="Retrieve solved questions similar to the user's query."),
    ]
    # Step 4: Format prompt and bind tools (remains the same)
    tool_definitions = "\n".join([f"- `{t.name}`: {t.description}" for t in tools_list])
    final_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(tools=tool_definitions)
    llm_with_tools = llm.bind_tools(tools_list)
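    # bind_tools attaches the tools' JSON schemas to every model call, so the LLM can emit
    # tool_calls that the graph's tools_condition edge (defined below) routes to the ToolNode.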
    # Step 5: Define Graph Nodes
    ## MODIFICATION: A new, powerful router node that replaces the previous pre-processing.
    def multimodal_router(state: MessagesState):
        """
        Inspects the user's message, classifies URLs, and prepares the state for the LLM.
        This node acts as a central dispatcher.
        """
        print("--- Entering Multimodal Router ---")
        messages = state["messages"]
        last_message = messages[-1]
        # 1. Perform knowledge base retrieval first
        # We consolidate this logic here from the old retriever_node
        user_query_text = ""
        if isinstance(last_message.content, str):
            user_query_text = last_message.content
        elif isinstance(last_message.content, list):  # For multimodal messages
            user_query_text = " ".join(
                item.get("text", "") for item in last_message.content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        docs = retriever.invoke(user_query_text)
        system_messages = [SystemMessage(content=final_system_prompt)]
        if docs:
            example_text = "\n\n---\n\n".join(d.page_content for d in docs)
            system_messages.append(AIMessage(content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}", name="ExampleRetriever"))
        # 2. Extract and classify URLs
        urls = re.findall(r'(https?://[^\s]+)', user_query_text)
        image_processed = False
        for url in urls:
            try:
                print(f"Routing URL: {url}")
                # Simple classification first
                if "youtube.com" in url or "youtu.be" in url:
                    system_messages.append(SystemMessage(content="[System Note: A YouTube URL has been detected. Use the 'process_youtube_video' tool if the user asks about it.]"))
                    continue
                # Use a HEAD request for robust classification
                headers = requests.head(url, timeout=5, allow_redirects=True).headers
                content_type = headers.get('Content-Type', '')
                if 'image/' in content_type and not image_processed:
                    print(" -> Classified as Image. Processing for vision model.")
                    response = requests.get(url, timeout=10)
                    response.raise_for_status()
                    img = Image.open(BytesIO(response.content))
                    buffered = BytesIO()
                    img.convert("RGB").save(buffered, format="JPEG")
                    b64_string = base64.b64encode(buffered.getvalue()).decode()
                    # Embed the image into the last message. Reusing the original message id
                    # ensures the add_messages reducer replaces the text-only message rather
                    # than appending a duplicate.
                    new_content = [
                        {"type": "text", "text": user_query_text},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}
                    ]
                    messages[-1] = HumanMessage(content=new_content, id=last_message.id)
                    image_processed = True  # Process only the first image for now
                elif 'audio/' in content_type:
                    print(" -> Classified as Audio.")
                    system_messages.append(SystemMessage(content="[System Note: An audio URL has been detected. Use the 'process_audio_file' tool if the user asks about it.]"))
                else:
                    print(" -> Classified as Web Page/Other.")
            except Exception as e:
                print(f" -> Could not process URL {url}: {e}")
        # Rebuild the final state
        final_messages = system_messages + messages
        return {"messages": final_messages}
    def assistant_node(state: MessagesState):
        result = llm_with_tools.invoke(state["messages"])
        return {"messages": [result]}
    # Step 6: Build Graph
    ## MODIFICATION: The graph is now simpler and more robust.
    builder = StateGraph(MessagesState)
    builder.add_node("multimodal_router", multimodal_router)  # The new, powerful starting node
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools_list))
    builder.add_edge(START, "multimodal_router")
    builder.add_edge("multimodal_router", "assistant")
    builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
    builder.add_edge("tools", "assistant")
    agent_executor = builder.compile()
    print("Agent Executor with Multimodal Router created successfully.")
    return agent_executor
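

# ----------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): assumes GROQ_API_KEY and
# TAVILY_API_KEY are set in the environment; the image URL below is a placeholder.
# ----------------------------------------------------------
if __name__ == "__main__":
    executor = create_agent_executor(provider="groq")
    # Plain-text question: the multimodal_router finds no URLs and simply prepends the
    # system prompt plus any retrieved examples before the assistant node runs.
    state = executor.invoke({"messages": [HumanMessage(content="Summarize the latest arxiv papers on retrieval-augmented generation.")]})
    print(state["messages"][-1].content)
    # Question with an image URL: the router downloads the image, base64-encodes it, and
    # rewrites the last message into multimodal (text + image_url) content for the vision model.
    state = executor.invoke({
        "messages": [HumanMessage(content="Describe this chart: https://example.com/chart.png")]
    })
    print(state["messages"][-1].content)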