# ----------------------------------------------------------
# Section 0: Imports
# ----------------------------------------------------------
import base64
import json
import pickle
import re
import subprocess
import sys
import textwrap
from io import BytesIO
from pathlib import Path
# Third-party libraries
import requests
from cachetools import TTLCache
from PIL import Image
# LangChain and associated libraries
from langchain.schema import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# Environment variable loading
from dotenv import load_dotenv
load_dotenv()
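# Depending on the provider, the following keys are read from the environment:
# GOOGLE_API_KEY (Gemini, also used by the vision tool), GROQ_API_KEY,
# HUGGINGFACEHUB_API_TOKEN, and TAVILY_API_KEY for the web search tool.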
# ----------------------------------------------------------
# Section 1: Constants and Configuration
# ----------------------------------------------------------
JSONL_PATH = Path("metadata.jsonl")
FAISS_CACHE = Path("faiss_index.pkl")
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
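# Note: all-mpnet-base-v2 emits 768-dimensional vectors; switching EMBED_MODEL
# invalidates a cached index, so delete FAISS_CACHE after changing it.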
RETRIEVER_K = 5 # Number of similar documents to retrieve
CACHE_TTL = 600 # Cache API calls for 10 minutes
# Global cache object for API calls
API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
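# Illustrative TTLCache behavior: entries expire CACHE_TTL seconds after insertion.
#   API_CACHE["web:llamas"] = results   # cached
#   API_CACHE.get("web:llamas")         # hit until ~10 minutes elapse, then None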
# ----------------------------------------------------------
# Section 2: The Agent Class
# ----------------------------------------------------------
class MyAgent:
"""
Encapsulates the agent's state, including LLMs, retriever, and tools.
This class-based approach ensures clean management of dependencies.
"""
def __init__(self, provider: str = "google"):
"""
Initializes the agent, setting up LLMs and the FAISS retriever.
Args:
provider (str): The LLM provider to use ('google', 'groq', 'huggingface').
"""
print(f"Initializing agent with provider: {provider}")
self.llm = self._build_llm(provider)
self.vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
self.retriever = self._get_retriever()
def _get_retriever(self):
"""Builds or loads the FAISS retriever from a local cache."""
embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
if FAISS_CACHE.exists():
print(f"Loading FAISS index from cache: {FAISS_CACHE}")
with open(FAISS_CACHE, "rb") as f:
vector_store = pickle.load(f)
else:
print("FAISS cache not found. Building new index from metadata.jsonl...")
if not JSONL_PATH.exists():
raise FileNotFoundError(f"{JSONL_PATH} not found. Cannot build vector store.")
            with open(JSONL_PATH, "rt", encoding="utf-8") as f:
                records = [json.loads(line) for line in f if line.strip()]
            docs = [
                Document(page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
                         metadata={"source": rec["task_id"]})
                for rec in records
            ]
            if not docs:
                raise ValueError("No documents found in metadata.jsonl.")
            vector_store = FAISS.from_documents(docs, embeddings)
            with open(FAISS_CACHE, "wb") as f:
                pickle.dump(vector_store, f)
            print(f"FAISS index built and saved to cache: {FAISS_CACHE}")
return vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
    def _build_llm(self, provider: str):
        """Helper to build the main text-based LLM for the chosen provider."""
        if provider == "google":
            return ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
        elif provider == "groq":
            return ChatGroq(model_name="llama3-70b-8192", temperature=0)
        elif provider == "huggingface":
            return ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct", temperature=0))
        else:
            raise ValueError("Provider must be 'google', 'groq', or 'huggingface'")
def _cached_get(self, key: str, fetch_fn):
"""Helper for caching API calls."""
        if key in API_CACHE:
            return API_CACHE[key]
val = fetch_fn()
API_CACHE[key] = val
return val
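    # Illustrative use (hypothetical key/fetcher):
    #   self._cached_get("web:llamas", lambda: search("llamas"))
    # calls the fetcher only on a cache miss within the TTL window.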
    # --- Tool Definitions as Class Methods ---
    # Note: @tool is deliberately not applied here. Decorating instance methods
    # at class-definition time bakes `self` into each tool's argument schema;
    # get_tools() instead wraps the *bound* methods with tool() at runtime.
def python_repl(self, code: str) -> str:
"""Executes a string of Python code and returns the stdout/stderr."""
code = textwrap.dedent(code).strip()
        try:
            # sys.executable ensures the same interpreter runs the snippet.
            result = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True, timeout=10, check=False)
            if result.returncode == 0:
                return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
            return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
        except subprocess.TimeoutExpired:
            return "Execution timed out (>10s)."
    def describe_image(self, image_source: str) -> str:
"""Describes an image from a local file path or a URL using Gemini vision."""
try:
if image_source.startswith("http"):
img = Image.open(BytesIO(requests.get(image_source, timeout=10).content))
else:
img = Image.open(image_source)
buffered = BytesIO()
img.convert("RGB").save(buffered, format="JPEG")
b64_string = base64.b64encode(buffered.getvalue()).decode()
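            # Gemini accepts inline images as base64 data URLs in multimodal
            # messages, hence the JPEG re-encoding above.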
msg = HumanMessage(content=[{"type": "text", "text": "Describe this image in detail."}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}])
return self.vision_llm.invoke([msg]).content
        except Exception as e:
            return f"Error processing image: {e}"
    def web_search(self, query: str) -> str:
"""Performs a web search using Tavily and returns a compilation of results."""
key = f"web:{query}"
results = self._cached_get(key, lambda: TavilySearchResults(max_results=5).invoke(query))
return "\n\n---\n\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in results])
    def wiki_search(self, query: str) -> str:
"""Searches Wikipedia and returns the top 2 results."""
key = f"wiki:{query}"
docs = self._cached_get(key, lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load())
return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\n\n{d.page_content}" for d in docs])
    def arxiv_search(self, query: str) -> str:
"""Searches Arxiv for scientific papers and returns the top 2 results."""
key = f"arxiv:{query}"
docs = self._cached_get(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])
    def get_tools(self) -> list:
        """Returns a list of all tools available to the agent."""
        # Wrapping the bound methods keeps `self` out of each tool's schema.
        tools_list = [
            tool(self.python_repl),
            tool(self.describe_image),
            tool(self.web_search),
            tool(self.wiki_search),
            tool(self.arxiv_search),
        ]
        retriever_tool = create_retriever_tool(
            retriever=self.retriever,
            name="retrieve_examples",
            description="Retrieve solved questions and answers similar to the user's query.",
        )
        tools_list.append(retriever_tool)
        return tools_list
# ----------------------------------------------------------
# Section 3: System Prompt
# ----------------------------------------------------------
SYSTEM_PROMPT = (
"""You are a helpful and expert assistant designed to answer questions accurately and concisely.
**Instructions:**
1. **Analyze the Question:** Carefully understand what is being asked.
2. **Use Tools:** You have a set of tools to find information. Use them logically.
3. **Synthesize the Answer:** Based on the information from the tools, formulate your final answer.
4. **Format the Output:** Your final response MUST be in the following format and nothing else:
FINAL ANSWER: [Your concise and accurate answer here]
If the `retrieve_examples` tool provides an answer to an identical question, use that answer. Otherwise, use your tools to find the correct answer for the current question.
"""
)
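# Example of a compliant final response (illustrative):
#   "FINAL ANSWER: Paris"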
# ----------------------------------------------------------
# Section 4: Factory Function for Agent Executor
# ----------------------------------------------------------
def create_agent_executor(provider: str = "google"):
"""Factory function to create and compile the LangGraph agent executor."""
my_agent_instance = MyAgent(provider=provider)
tools_list = my_agent_instance.get_tools()
llm_with_tools = my_agent_instance.llm.bind_tools(tools_list)
def retriever_node(state: MessagesState):
"""First node: retrieves examples and prepends them to the message history."""
user_query = state["messages"][-1].content
docs = my_agent_instance.retriever.invoke(user_query)
messages = [SystemMessage(content=SYSTEM_PROMPT)]
if docs:
example_text = "\n\n---\n\n".join(d.page_content for d in docs)
example_msg = AIMessage(content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}", name="ExampleRetriever")
messages.append(example_msg)
messages.extend(state["messages"])
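        # MessagesState merges via add_messages: re-sent messages that carry the
        # same IDs are updated in place rather than appended twice.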
return {"messages": messages}
def assistant_node(state: MessagesState):
"""Main assistant node: calls the LLM with the current state to decide the next action."""
result = llm_with_tools.invoke(state["messages"])
return {"messages": [result]}
builder = StateGraph(MessagesState)
builder.add_node("retriever", retriever_node)
builder.add_node("assistant", assistant_node)
builder.add_node("tools", ToolNode(tools_list))
builder.add_edge(START, "retriever")
builder.add_edge("retriever", "assistant")
builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
builder.add_edge("tools", "assistant")
agent_executor = builder.compile()
print("Agent Executor created successfully.")
return agent_executor
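# Minimal usage sketch (assumes the relevant API keys are set in the environment):
#   executor = create_agent_executor(provider="google")
#   state = executor.invoke({"messages": [("user", "What is 2 + 2?")]})
#   print(state["messages"][-1].content)  # expected: "FINAL ANSWER: 4"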
# ----------------------------------------------------------
# Section 5: Direct Execution Block for Testing
# ----------------------------------------------------------
if __name__ == "__main__":
"""direct testing of the agent's logic."""
print("--- Running Agent in Test Mode ---")
agent = create_agent_executor(provider="google")
question = "According to wikipedia, what is the main difference between a lama and an alpaca?"
print(f"\nTest Question: {question}\n\n--- Agent Thinking... ---\n")
for chunk in agent.stream({"messages": [("user", question)]}):
for key, value in chunk.items():
if value['messages']:
message = value['messages'][-1]
                if message.content:
                    print(f"--- Node: {key} ---\n{message.content}\n")
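    # Illustrative post-processing (re-runs the graph once with invoke()): pull
    # the text after "FINAL ANSWER:", assuming the model followed the format.
    final_state = agent.invoke({"messages": [("user", question)]})
    final_text = final_state["messages"][-1].content
    match = re.search(r"FINAL ANSWER:\s*(.+)", final_text)
    print(f"\nExtracted answer: {match.group(1).strip() if match else final_text}")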