# NOTE: stray Hugging Face Spaces page text ("Spaces" / "Sleeping") removed —
# it was scrape residue, not part of this module.
# backend.py
"""
Backend module for MCP Agent.
Handles all the MCP server connections, LLM setup, and agent logic.
"""
import sys
import os
import re
import asyncio

from dotenv import load_dotenv
from typing import Optional, Dict, List, Any
from pathlib import Path

# BUG FIX: load_dotenv was imported but never invoked, so a local .env file
# was silently ignored.  Call it BEFORE the os.getenv() reads in the auth
# setup below.
load_dotenv()

# Directory containing this file; used to locate the MCP server scripts.
here = Path(__file__).parent.resolve()
################### --- Auth setup --- ###################
##########################################################
# Resolve the HF API token from either supported env var name.
HF = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF:
    # Degrade gracefully instead of crashing at import time.
    print("WARNING: HF_TOKEN not set. The app will start, but model calls may fail.")
else:
    # Mirror the token into BOTH env var names so downstream libraries
    # (huggingface_hub, langchain_huggingface) find it under either key.
    os.environ["HF_TOKEN"] = HF
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF
    try:
        # Best-effort login; failure is deliberately swallowed because the
        # token is still available via the env vars set above.
        from huggingface_hub import login
        login(token=HF)
    except Exception:
        pass
# --- LangChain / MCP ---
from langgraph.prebuilt import create_react_agent
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_core.messages import HumanMessage

# Candidate chat models, tried in order by build_chat_llm_with_fallback():
# the first choice (overridable via HF_MODEL_ID), then free/tiny fallbacks.
CANDIDATE_MODELS = [
    os.getenv("HF_MODEL_ID", "Qwen/Qwen2.5-7B-Instruct"),
    "HuggingFaceTB/SmolLM3-3B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
]

# System prompt steering the ReAct agent's tool use (tool names must match
# the MCP servers launched in MCPAgent.initialize).
SYSTEM_PROMPT = (
    "You are an AI assistant with tools.\n"
    "- Use the arithmetic tools (`add`, `minus`, `multiply`, `divide`) for arithmetic or multi-step calculations.\n"
    "- Use the stock tools (`get_stock_price`, `get_market_summary`, `get_company_news`) for financial/market queries.\n"
    "- Otherwise, answer directly with your own knowledge.\n"
    "Be concise and accurate. Only call tools when they clearly help."
)
################### --- ROUTER helpers --- ###################
##############################################################

# Detect stock/financial intent
def is_stock_query(q: str) -> bool:
    """Return True when the query looks like a stock/market/finance question."""
    finance_res = (
        r"\b(stock|share|price|ticker|market|nasdaq|dow|s&p|spy|qqq)\b",
        r"\b(AAPL|GOOGL|MSFT|TSLA|AMZN|META|NVDA|AMD)\b",  # Common tickers
        r"\$[A-Z]{1,5}\b",  # $SYMBOL format
        r"\b(trading|invest|portfolio|earnings|dividend)\b",
        r"\b(bull|bear|rally|crash|volatility)\b",
    )
    for expr in finance_res:
        if re.search(expr, q, re.I):
            return True
    return False
# Extract ticker symbol from query
def extract_ticker(q: str) -> Optional[str]:
    """Best-effort extraction of a stock ticker from free text.

    Returns the upper-cased symbol, or None when nothing ticker-like is
    found.  (Annotation fixed: the original declared ``-> str`` but
    returned None on no match.)
    """
    # 1. $SYMBOL shorthand takes priority.
    dollar_match = re.search(r"\$([A-Z]{1,5})\b", q, re.I)
    if dollar_match:
        return dollar_match.group(1).upper()
    # 2. Phrases such as "price of AAPL" or "AAPL stock".
    # NOTE(review): re.I means these also match lower-case words
    # (e.g. "price of apple" -> "APPLE"); kept as-is to preserve behavior.
    patterns = [
        r"(?:price of|stock price of|quote for)\s+([A-Z]{1,5})\b",
        r"\b([A-Z]{1,5})\s+(?:stock|share|price|quote)",
        r"(?:what is|what's|get)\s+([A-Z]{1,5})\b",
    ]
    for pattern in patterns:
        match = re.search(pattern, q, re.I)
        if match:
            return match.group(1).upper()
    # 3. Fall back to any standalone 2-5 char all-caps word.
    for word in q.split():
        clean_word = word.strip(".,!?")
        if 2 <= len(clean_word) <= 5 and clean_word.isupper():
            return clean_word
    return None
# Check if asking for market summary
def wants_market_summary(q: str) -> bool:
    """Return True when the user asks about the overall market, not one stock."""
    summary_res = (
        r"\bmarket\s+(?:summary|overview|today|status)\b",
        r"\bhow(?:'s| is) the market\b",
        r"\b(?:dow|nasdaq|s&p)\s+(?:today|now)\b",
        r"\bmarket indices\b",
    )
    for expr in summary_res:
        if re.search(expr, q, re.I):
            return True
    return False
# Check if asking for news
def wants_news(q: str) -> bool:
    """Return True when the query mentions news/headlines/announcements/updates."""
    hit = re.search(r"\b(news|headline|announcement|update)\b", q, re.I)
    return hit is not None
def build_tool_map(tools):
    """Index the given tools by their ``name`` attribute."""
    return {tool.name: tool for tool in tools}
def find_tool(name: str, tool_map: dict):
    """Look up a tool by short name, case-insensitively.

    Also matches namespaced keys of the form ``server/name``; returns None
    when nothing matches.
    """
    wanted = name.lower()
    for key, tool in tool_map.items():
        lowered = key.lower()
        if lowered == wanted or lowered.endswith("/" + wanted):
            return tool
    return None
async def build_chat_llm_with_fallback():
    """Build a ChatHuggingFace model, walking CANDIDATE_MODELS until one works.

    Each candidate is wrapped in HuggingFaceEndpoint + ChatHuggingFace and
    probed with a tiny HumanMessage so auth/routing errors surface at once.
    402 / Payment Required (or any other failure) moves on to the next
    candidate; if none succeed, a RuntimeError reports the last error.
    """
    last_err = None
    for model_id in CANDIDATE_MODELS:
        try:
            endpoint = HuggingFaceEndpoint(
                repo_id=model_id,
                huggingfacehub_api_token=HF,
                temperature=0.1,
                max_new_tokens=256,  # roomier ceiling for better responses
            )
            chat = ChatHuggingFace(llm=endpoint)
            # Probe with a valid LC message type so the endpoint is really hit.
            await chat.ainvoke([HumanMessage(content="ping")])
            print(f"[LLM] Using: {model_id}")
            return chat
        except Exception as err:
            last_err = err
            text = str(err)
            if "402" in text or "Payment Required" in text:
                print(f"[LLM] {model_id} requires payment; trying next...")
            else:
                print(f"[LLM] {model_id} error: {err}; trying next...")
    raise RuntimeError(f"Could not initialize any candidate model. Last error: {last_err}")
# Class managing the MCP Agent lifecycle (client, tools, LLM, ReAct agent).
class MCPAgent:
    """Owns the MCP server connections and routes user messages to tools."""

    def __init__(self):
        # All wiring is deferred to initialize() so construction is cheap.
        self.client = None       # MultiServerMCPClient, set in initialize()
        self.agent = None        # LangGraph ReAct agent
        self.tool_map = None     # tool name -> tool object
        self.tools = None        # raw tool list from the MCP servers
        self.model = None        # ChatHuggingFace model
        self.initialized = False

    async def initialize(self):
        """Start MCP servers, load tools, build the LLM and the ReAct agent.

        Returns the list of available tool names.  (Fixed: the original
        returned None when already initialized but the tool list on the
        first call — now consistent on every call.)
        """
        if self.initialized:
            return list(self.tool_map.keys())
        # Spawn both MCP servers as stdio child processes using this
        # interpreter; scripts are resolved relative to this file.
        self.client = MultiServerMCPClient({
            "arithmetic": {
                "command": sys.executable,
                "args": [str(here / "arithmetic_server.py")],
                "transport": "stdio",
            },
            "stocks": {
                "command": sys.executable,
                "args": [str(here / "stock_server.py")],
                "transport": "stdio",
            },
        })
        # 1. MCP client + tools
        self.tools = await self.client.get_tools()
        self.tool_map = build_tool_map(self.tools)
        # 2. Build LLM with auto-fallback across candidate models.
        self.model = await build_chat_llm_with_fallback()
        # 3. ReAct agent wired to the MCP tools.
        self.agent = create_react_agent(self.model, self.tools)
        self.initialized = True
        return list(self.tool_map.keys())

    async def process_message(self, user_text: str, history: List[Dict]) -> str:
        """Answer one user message, preferring direct tool calls for stock queries.

        Routing order: market summary -> company news -> single-stock price;
        anything unmatched (or a missing tool/ticker) falls through to the
        full ReAct agent with the system prompt and chat history.
        """
        if not self.initialized:
            await self.initialize()
        # Direct stock-tool routing avoids a full agent round-trip.
        if is_stock_query(user_text):
            if wants_market_summary(user_text):
                market_tool = find_tool("get_market_summary", self.tool_map)
                if market_tool:
                    return await market_tool.ainvoke({})
            elif wants_news(user_text):
                ticker = extract_ticker(user_text)
                if ticker:
                    news_tool = find_tool("get_company_news", self.tool_map)
                    if news_tool:
                        return await news_tool.ainvoke({"symbol": ticker, "limit": 3})
            else:
                ticker = extract_ticker(user_text)
                if ticker:
                    price_tool = find_tool("get_stock_price", self.tool_map)
                    if price_tool:
                        return await price_tool.ainvoke({"symbol": ticker})
        # Fall back to the agent for everything else.
        messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history + [
            {"role": "user", "content": user_text}
        ]
        result = await self.agent.ainvoke({"messages": messages})
        return result["messages"][-1].content

    async def cleanup(self):
        """Release the MCP client; tolerates a sync or async close(), or none."""
        if self.client:
            close = getattr(self.client, "close", None)
            if callable(close):
                res = close()
                if asyncio.iscoroutine(res):
                    await res
# Module-level holder for the shared singleton agent.
_agent_instance = None

def get_agent() -> MCPAgent:
    """Lazily build and return the process-wide MCPAgent singleton."""
    global _agent_instance
    if _agent_instance is not None:
        return _agent_instance
    _agent_instance = MCPAgent()
    return _agent_instance