# Page header captured with this file: author "Abbasid"; commit bd93d23
# ("Update agent.py", verified); file size 16.9 kB.
"""
agent.py
This file defines the core logic for a sophisticated AI agent using LangGraph.
This version includes proper multimodal support for images, YouTube videos, and audio files.
"""
# ----------------------------------------------------------
# Section 0: Imports and Configuration
# ----------------------------------------------------------
import json
import os
import pickle
import re
import subprocess
import textwrap
import base64
import functools
from io import BytesIO
from pathlib import Path
import tempfile
import yt_dlp
from pydub import AudioSegment
import speech_recognition as sr
import requests
from cachetools import TTLCache
from PIL import Image
from langchain.schema import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import Tool, tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from dotenv import load_dotenv
load_dotenv()
# --- Configuration and Caching ---
# Solved-example records, one JSON object per line; source for the retriever index.
JSONL_PATH = Path("metadata.jsonl")
# Pickled FAISS vector store built from JSONL_PATH and reused on later runs.
FAISS_CACHE = Path("faiss_index.pkl")
# Embedding model name passed to HuggingFaceEmbeddings.
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
# Number of similar examples the retriever returns per query.
RETRIEVER_K = 5
# Lifetime of a cached API response (TTLCache time units, seconds by default).
CACHE_TTL = 600
# Bounded in-memory cache used by cached_get for web/wiki/arxiv lookups.
API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
def cached_get(key: str, fetch_fn):
    """Return the value cached under ``key``, computing and storing it on a miss.

    Args:
        key: Cache key; callers namespace it (e.g. ``"web:<query>"``).
        fetch_fn: Zero-argument callable, invoked only when ``key`` is absent
            or expired.

    Returns:
        The cached or freshly fetched value. Values are stored in the
        module-level ``API_CACHE``. Exceptions from ``fetch_fn`` propagate and
        nothing is cached for that key.
    """
    # EAFP: a separate ``key in API_CACHE`` check followed by ``API_CACHE[key]``
    # can race with TTL expiry and raise KeyError between the two calls.
    try:
        return API_CACHE[key]
    except KeyError:
        pass
    val = fetch_fn()
    API_CACHE[key] = val
    return val
# ----------------------------------------------------------
# Section 2: Standalone Tool Functions
# ----------------------------------------------------------
@tool
def python_repl(code: str) -> str:
    """Executes a string of Python code and returns the stdout/stderr."""
    # NOTE(security): this executes model-generated code in a child process
    # with the agent's own privileges; only the 10 s timeout bounds it.
    # (The docstring above doubles as the tool description sent to the LLM.)
    import sys  # local import keeps this tool self-contained

    code = textwrap.dedent(code).strip()
    # Run the same interpreter as the agent instead of whatever "python"
    # resolves to on PATH (it may be absent, or a different version/venv).
    interpreter = sys.executable or "python"
    try:
        result = subprocess.run(
            [interpreter, "-c", code],
            capture_output=True,
            text=True,
            timeout=10,
            check=False,
        )
    except subprocess.TimeoutExpired:
        return "Execution timed out (>10s)."
    if result.returncode == 0:
        return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
    return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
def describe_image_func(image_source: str, vision_llm_instance) -> str:
    """Describe an image from a local file path or a URL using a vision LLM.

    Args:
        image_source: A URL (anything starting with "http") fetched with
            ``requests``, or otherwise a local filesystem path.
        vision_llm_instance: Chat model whose ``.invoke`` accepts a
            ``HumanMessage`` with a text part and a base64 ``image_url`` part.

    Returns:
        ``"Image description: <model output>"`` on success, or
        ``"Error processing image: <error>"`` on any failure (never raises).
    """
    try:
        print(f"Processing image: {image_source}")
        if image_source.startswith("http"):
            response = requests.get(image_source, timeout=10)
            response.raise_for_status()
            source = BytesIO(response.content)
        else:
            source = image_source
        # Re-encode as JPEG for the data URL. The context manager closes the
        # handle that the lazy Image.open keeps on a local file.
        buffered = BytesIO()
        with Image.open(source) as img:
            # JPEG cannot store alpha/palette modes, so normalise to RGB first.
            img.convert("RGB").save(buffered, format="JPEG")
        b64_string = base64.b64encode(buffered.getvalue()).decode()
        msg = HumanMessage(content=[
            {"type": "text", "text": "Describe this image in detail. Include all objects, people, text, colors, setting, and any other relevant information you can see."},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}},
        ])
        result = vision_llm_instance.invoke([msg])
        return f"Image description: {result.content}"
    except Exception as e:
        # Deliberate catch-all: the tool reports failures to the agent as text.
        print(f"Error in describe_image_func: {e}")
        return f"Error processing image: {e}"
@tool
def process_youtube_video(url: str) -> str:
    """Downloads and processes a YouTube video, extracting audio and converting to text."""
    # Pipeline:
    #   1. yt-dlp downloads the best audio stream; its FFmpeg post-processor
    #      converts it to WAV inside a temporary directory.
    #   2. pydub downmixes the audio to 16 kHz mono.
    #   3. Up to `max_chunks` 30-second chunks go to Google speech recognition.
    # Failures are returned as text, not raised. The docstring above is the
    # tool description shown to the LLM, so it is kept verbatim.
    chunk_length_ms = 30000
    max_chunks = 10  # 10 x 30 s = the "first 5 minutes" named in the output
    try:
        print(f"Processing YouTube video: {url}")
        # All downloaded/exported files live here and are deleted on exit.
        with tempfile.TemporaryDirectory() as temp_dir:
            ydl_opts = {
                'format': 'bestaudio/best',
                'outtmpl': f'{temp_dir}/%(title)s.%(ext)s',
                # A watch URL that also carries "&list=..." would otherwise make
                # yt-dlp download every video in the playlist.
                'noplaylist': True,
                'postprocessors': [{
                    'key': 'FFmpegExtractAudio',
                    'preferredcodec': 'wav',
                }],
            }
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                info = ydl.extract_info(url, download=True)
                title = info.get('title', 'Unknown')

            audio_files = list(Path(temp_dir).glob("*.wav"))
            if not audio_files:
                return "Error: Could not download audio from YouTube video"

            audio = AudioSegment.from_wav(str(audio_files[0]))
            audio = audio.set_channels(1).set_frame_rate(16000)

            recognizer = sr.Recognizer()
            transcript_parts = []
            # Slice only the chunks that will be transcribed. Materialising
            # every 30 s chunk of a long video first copies the whole track.
            limit_ms = min(len(audio), chunk_length_ms * max_chunks)
            for i, start in enumerate(range(0, limit_ms, chunk_length_ms)):
                chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
                audio[start:start + chunk_length_ms].export(chunk_file, format="wav")
                try:
                    with sr.AudioFile(str(chunk_file)) as source:
                        audio_data = recognizer.record(source)
                    transcript_parts.append(recognizer.recognize_google(audio_data))
                except sr.UnknownValueError:
                    transcript_parts.append("[Unintelligible audio]")
                except sr.RequestError as e:
                    transcript_parts.append(f"[Speech recognition error: {e}]")

            transcript = " ".join(transcript_parts)
            return f"YouTube Video: {title}\n\nTranscript (first 5 minutes):\n{transcript}"
    except Exception as e:
        print(f"Error processing YouTube video: {e}")
        return f"Error processing YouTube video: {e}"
@tool
def process_audio_file(file_url: str) -> str:
    """Downloads and processes an audio file (MP3, WAV, etc.) and converts to text."""
    # Downloads `file_url` into a temporary directory, decodes it with pydub,
    # downmixes to 16 kHz mono and transcribes up to `max_chunks` 30-second
    # chunks with Google speech recognition. Failures are returned as text,
    # not raised. The docstring above is the tool description shown to the
    # LLM, so it is kept verbatim.
    from urllib.parse import urlparse

    chunk_length_ms = 30000
    max_chunks = 20  # 20 x 30 s = first 10 minutes
    try:
        print(f"Processing audio file: {file_url}")
        with tempfile.TemporaryDirectory() as temp_dir:
            response = requests.get(file_url, timeout=30)
            response.raise_for_status()

            # Choose the temp-file extension from the URL *path*, so
            # "clip.wav?token=..." is still recognised. Then fall back to the
            # Content-Type header, and finally to mp3.
            url_path = urlparse(file_url).path.lower()
            content_type = response.headers.get('content-type', '')
            if url_path.endswith('.mp3'):
                ext = 'mp3'
            elif url_path.endswith('.wav'):
                ext = 'wav'
            elif 'mp3' in content_type:
                ext = 'mp3'
            elif 'wav' in content_type:
                ext = 'wav'
            else:
                ext = 'mp3'  # Default assumption

            audio_file = Path(temp_dir) / f"audio.{ext}"
            audio_file.write_bytes(response.content)

            # from_file decodes *.wav directly and anything else through
            # ffmpeg with the format auto-detected. The decoded segment is used
            # as-is; no WAV re-export and re-decode round trip is needed.
            audio = AudioSegment.from_file(str(audio_file))
            audio = audio.set_channels(1).set_frame_rate(16000)

            recognizer = sr.Recognizer()
            transcript_parts = []
            # Slice only the chunks that will be transcribed.
            limit_ms = min(len(audio), chunk_length_ms * max_chunks)
            for i, start in enumerate(range(0, limit_ms, chunk_length_ms)):
                chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
                audio[start:start + chunk_length_ms].export(chunk_file, format="wav")
                try:
                    with sr.AudioFile(str(chunk_file)) as source:
                        audio_data = recognizer.record(source)
                    transcript_parts.append(recognizer.recognize_google(audio_data))
                except sr.UnknownValueError:
                    transcript_parts.append("[Unintelligible audio]")
                except sr.RequestError as e:
                    transcript_parts.append(f"[Speech recognition error: {e}]")

            transcript = " ".join(transcript_parts)
            return f"Audio file transcript:\n{transcript}"
    except Exception as e:
        print(f"Error processing audio file: {e}")
        return f"Error processing audio file: {e}"
def web_search_func(query: str, cache_func) -> str:
    """Search the web with Tavily and format the hits as plain text.

    Args:
        query: Search query.
        cache_func: Memoizer called as ``cache_func(key, fetch_fn)`` (e.g.
            ``cached_get``); results are keyed by ``"web:<query>"``.

    Returns:
        Hits joined by a ``---`` separator, each rendered as
        ``"Source: <url>\\nContent: <content>"``; ``""`` when there are no hits.
        If the search yields a plain string instead of a list of hits, that
        string is returned unchanged.
    """
    key = f"web:{query}"
    results = cache_func(key, lambda: TavilySearchResults(max_results=5).invoke(query))
    # NOTE(review): the langchain Tavily tool appears to return ``repr(exc)``
    # (a str) on failure. Iterating that char-by-char would crash on
    # ``res['url']``; pass the message through instead. Verify for the
    # installed version.
    if isinstance(results, str):
        return results
    return "\n\n---\n\n".join(
        f"Source: {res.get('url', 'unknown')}\nContent: {res.get('content', '')}"
        for res in results
    )
def wiki_search_func(query: str, cache_func) -> str:
    """Searches Wikipedia and returns the top 2 results.

    Args:
        query: Search query.
        cache_func: Memoizer called as ``cache_func(key, fetch_fn)``; the
            loaded documents are cached under ``"wiki:<query>"``.

    Returns:
        Each page as ``"Source: <source>\\n\\n<text>"`` (text capped at 2000
        characters by the loader), pages separated by ``---``.
    """
    def fetch():
        return WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load()

    pages = cache_func(f"wiki:{query}", fetch)
    sections = []
    for page in pages:
        sections.append(f"Source: {page.metadata['source']}\n\n{page.page_content}")
    separator = "\n\n---\n\n"
    return separator.join(sections)
def arxiv_search_func(query: str, cache_func) -> str:
    """Search arXiv for scientific papers and return the top 2 results as text.

    Args:
        query: Search query.
        cache_func: Memoizer called as ``cache_func(key, fetch_fn)``; the
            loaded documents are cached under ``"arxiv:<query>"``.

    Returns:
        One section per paper (Source, Published, Title, then the document
        text), sections separated by ``---``; ``""`` when nothing is found.
        Missing metadata fields render as ``"unknown"`` instead of raising.
    """
    key = f"arxiv:{query}"
    docs = cache_func(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())

    def _format(doc) -> str:
        """Render one loaded paper as a text section."""
        meta = doc.metadata
        # NOTE(review): langchain's ArxivLoader metadata appears to carry
        # Published/Title/Authors/Summary (plus entry_id only with
        # load_all_available_meta) and *no* "source" key, so the previous
        # meta['source'] raised KeyError. Prefer "source" when present.
        source = meta.get("source") or meta.get("entry_id") or "unknown"
        return (
            f"Source: {source}\n"
            f"Published: {meta.get('Published', 'unknown')}\n"
            f"Title: {meta.get('Title', 'unknown')}\n\n"
            f"Summary:\n{doc.page_content}"
        )

    return "\n\n---\n\n".join(_format(d) for d in docs)
# ----------------------------------------------------------
# Section 3: DYNAMIC SYSTEM PROMPT
# ----------------------------------------------------------
# System prompt template for the agent. Its single ``{tools}`` placeholder is
# filled in create_agent_executor via ``str.format`` with a bullet list of each
# bound tool's name and description. Any literal "{" or "}" added to this text
# must be doubled ("{{" / "}}"), otherwise ``str.format`` will fail.
SYSTEM_PROMPT_TEMPLATE = (
"""You are an expert-level multimodal research assistant. Your goal is to answer the user's question accurately using all available tools.
**CRITICAL INSTRUCTIONS:**
1. **USE YOUR TOOLS:** You have been given a set of tools to find information. You MUST use them when the answer is not immediately known to you. Do not make up answers.
2. **MULTIMODAL PROCESSING:** When you encounter URLs or attachments:
- For image URLs (jpg, png, gif, etc.): Use the `describe_image` tool
- For YouTube URLs: Use the `process_youtube_video` tool
- For audio files (mp3, wav, etc.): Use the `process_audio_file` tool
- For other content: Use appropriate search tools
3. **AVAILABLE TOOLS:** Here is the exact list of tools you have access to:
{tools}
4. **REASONING:** Think step-by-step. First, analyze the user's question and any attachments. Second, decide which tools are appropriate. Third, call the tools with correct parameters. Finally, synthesize the results.
5. **URL DETECTION:** Look for URLs in the user's message, especially in brackets like [Attachment URL: ...]. Process these appropriately.
6. **FINAL ANSWER FORMAT:** Your final response MUST strictly follow this format:
`FINAL ANSWER: [Your comprehensive answer incorporating all tool results]`
"""
)
# ----------------------------------------------------------
# Section 4: Factory Function for Agent Executor
# ----------------------------------------------------------
def _build_llms(provider: str):
    """Return ``(main_llm, vision_llm)`` for ``provider``.

    ``main_llm`` drives the agent loop and tool calling. ``vision_llm`` backs
    the ``describe_image`` tool and is always a Gemini model.

    Raises:
        ValueError: if ``provider`` is not "google", "groq" or "huggingface".
    """
    if provider == "google":
        main_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
    elif provider == "groq":
        main_llm = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
    elif provider == "huggingface":
        main_llm = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.1))
    else:
        raise ValueError(f"Invalid provider selected: {provider!r}")
    # Vision always goes through Gemini, including for "groq" (its vision
    # support may be limited). Every provider needs a vision_llm for the
    # describe_image tool; the "groq" branch used to leave it unassigned.
    vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
    return main_llm, vision_llm


def _load_or_build_vector_store(embeddings):
    """Return a FAISS store of solved examples, cached on disk at FAISS_CACHE.

    - If FAISS_CACHE exists, unpickle it.
    - Otherwise embed the JSONL_PATH records (fields "Question",
      "Final answer", "task_id") and pickle the resulting store.
    - With no usable records, return an uncached one-document placeholder
      store so the retriever can still be built.
    """
    if FAISS_CACHE.exists():
        # NOTE(security): pickle.load can execute arbitrary code. Only load a
        # cache file this application wrote itself, never one obtained elsewhere.
        with open(FAISS_CACHE, "rb") as f:
            return pickle.load(f)
    docs = []
    if JSONL_PATH.exists():
        with open(JSONL_PATH, "rt", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline): json.loads("") raises.
            records = [json.loads(line) for line in f if line.strip()]
        docs = [
            Document(
                page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
                metadata={"source": rec["task_id"]},
            )
            for rec in records
        ]
    if not docs:
        # FAISS cannot build an index from an empty document list.
        placeholder = [Document(page_content="Sample document", metadata={"source": "sample"})]
        return FAISS.from_documents(placeholder, embeddings)
    vector_store = FAISS.from_documents(docs, embeddings)
    with open(FAISS_CACHE, "wb") as f:
        pickle.dump(vector_store, f)
    return vector_store


def _message_text(message) -> str:
    """Return the plain-text portion of a chat message's content.

    String content is returned unchanged. List (multimodal) content keeps its
    ``str`` items and ``{"type": "text"}`` parts, joined with spaces.
    """
    content = message.content
    if isinstance(content, str):
        return content
    texts = []
    for part in content:
        if isinstance(part, str):
            texts.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            texts.append(part.get("text", ""))
    return " ".join(texts)


def create_agent_executor(provider: str = "groq"):
    """Build and compile the LangGraph agent executor.

    Graph shape: START -> retriever -> assistant, then assistant <-> tools
    until the assistant answers without tool calls.

    Args:
        provider: "google", "groq" (default) or "huggingface"; selects the
            main chat model. Image description always uses Gemini.

    Returns:
        The compiled ``StateGraph(MessagesState)``.

    Raises:
        ValueError: for an unknown ``provider``.
    """
    print(f"Initializing agent with provider: {provider}")
    # Step 1: Build LLMs
    main_llm, vision_llm = _build_llms(provider)

    # Step 2: Build retriever over solved examples
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    vector_store = _load_or_build_vector_store(embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

    # Step 3: Create the final list of tools
    tools_list = [
        python_repl,
        Tool(name="describe_image", func=functools.partial(describe_image_func, vision_llm_instance=vision_llm), description="Describes an image from a local file path or a URL. Use this for any image files or image URLs."),
        process_youtube_video,
        process_audio_file,
        Tool(name="web_search", func=functools.partial(web_search_func, cache_func=cached_get), description="Performs a web search using Tavily."),
        Tool(name="wiki_search", func=functools.partial(wiki_search_func, cache_func=cached_get), description="Searches Wikipedia."),
        Tool(name="arxiv_search", func=functools.partial(arxiv_search_func, cache_func=cached_get), description="Searches Arxiv for scientific papers."),
        create_retriever_tool(retriever=retriever, name="retrieve_examples", description="Retrieve solved questions similar to the user's query."),
    ]

    # Step 4: Render the tool list into the system prompt and bind the tools
    tool_definitions = "\n".join(f"- `{t.name}`: {t.description}" for t in tools_list)
    final_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(tools=tool_definitions)
    llm_with_tools = main_llm.bind_tools(tools_list)

    # Step 5: Define graph nodes
    def retriever_node(state: MessagesState):
        """Add one SystemMessage holding the prompt plus retrieved examples."""
        # MessagesState's add_messages reducer appends new messages after the
        # existing ones; returning the old list again does not move them. So
        # this node cannot place the prompt before the user's question in
        # state, and assistant_node reorders at call time instead. Folding the
        # examples into the system message also keeps the LLM input from
        # ending on an AI turn.
        user_query = _message_text(state["messages"][-1])
        docs = retriever.invoke(user_query)
        content = final_system_prompt
        if docs:
            example_text = "\n\n---\n\n".join(d.page_content for d in docs)
            content += f"\n\nHere are {len(docs)} similar solved examples for reference:\n\n{example_text}"
        return {"messages": [SystemMessage(content=content)]}

    def assistant_node(state: MessagesState):
        """Invoke the tool-bound LLM with system messages moved to the front."""
        messages = state["messages"]
        system = [m for m in messages if isinstance(m, SystemMessage)]
        conversation = [m for m in messages if not isinstance(m, SystemMessage)]
        result = llm_with_tools.invoke(system + conversation)
        return {"messages": [result]}

    # Step 6: Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools_list))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
    builder.add_edge("tools", "assistant")
    agent_executor = builder.compile()
    print("Agent Executor created successfully.")
    return agent_executor