mcp-stock-math / backend.py
Entz's picture
Upload 8 files
effc62f verified
raw
history blame
8.9 kB
# backend.py
"""
Backend module for MCP Agent
Handles all the MCP server connections, LLM setup, and agent logic
"""
import sys
import os
import re
import asyncio
from dotenv import load_dotenv
from typing import Optional, Dict, List, Any
from pathlib import Path # already imported
here = Path(__file__).parent.resolve()
################### --- Auth setup --- ###################
##########################################################
# load_dotenv is imported above but was never called; call it so a local .env
# file can supply HF_TOKEN before we look it up.  By default it does not
# override variables that are already set in the environment.
load_dotenv()
HF = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF:
    print("WARNING: HF_TOKEN not set. The app will start, but model calls may fail.")
else:
    # Publish the token under both variable names so code reading either finds it.
    os.environ["HF_TOKEN"] = HF
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF
    try:
        from huggingface_hub import login
        login(token=HF)
    except Exception as exc:
        # Best-effort: HF is also passed explicitly as huggingfacehub_api_token
        # when the endpoint is built, so a failed login should not stop the
        # app -- but report it instead of swallowing it silently.
        print(f"WARNING: huggingface_hub login failed: {exc}")
# --- LangChain / MCP ---
from langgraph.prebuilt import create_react_agent
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_core.messages import HumanMessage
# First choice (overridable via HF_MODEL_ID), then free/tiny fallbacks.
_DEFAULT_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
_FALLBACK_MODEL_IDS = [
    # NOTE(review): confirm this repo id exists on the Hub; a failing id just
    # falls through to the next candidate in build_chat_llm_with_fallback.
    "HuggingFaceTB/SmolLM3-3B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
]
# `or` rather than a getenv default: an exported-but-empty (or blank)
# HF_MODEL_ID must still fall back to the default instead of yielding "".
# dict.fromkeys keeps first-seen order while dropping duplicates, so an
# HF_MODEL_ID naming one of the fallbacks is only probed once.
CANDIDATE_MODELS = list(dict.fromkeys([
    (os.getenv("HF_MODEL_ID") or "").strip() or _DEFAULT_MODEL_ID,
    *_FALLBACK_MODEL_IDS,
]))
SYSTEM_PROMPT = (
    "You are an AI assistant with tools.\n"
    "- Use the arithmetic tools (`add`, `minus`, `multiply`, `divide`) for arithmetic or multi-step calculations.\n"
    "- Use the stock tools (`get_stock_price`, `get_market_summary`, `get_company_news`) for financial/market queries.\n"
    "- Otherwise, answer directly with your own knowledge.\n"
    "Be concise and accurate. Only call tools when they clearly help."
)
################### --- ROUTER helpers --- ###################
##############################################################
# Detect stock/financial intent
def is_stock_query(q: str) -> bool:
    """Return True when *q* looks like a stock, market, or finance question.

    Every pattern below is tried case-insensitively; the first one found
    anywhere in the query decides the answer.
    """
    keyword_regexes = (
        r"\b(stock|share|price|ticker|market|nasdaq|dow|s&p|spy|qqq)\b",
        r"\b(AAPL|GOOGL|MSFT|TSLA|AMZN|META|NVDA|AMD)\b",  # Common tickers
        r"\$[A-Z]{1,5}\b",  # $SYMBOL format
        r"\b(trading|invest|portfolio|earnings|dividend)\b",
        r"\b(bull|bear|rally|crash|volatility)\b",
    )
    for regex in keyword_regexes:
        if re.search(regex, q, re.I):
            return True
    return False
# Extract ticker symbol from query

# Short everyday words that the case-insensitive phrase patterns in
# extract_ticker can capture ("what is THE ...", "THE stock price",
# "get ME ...", "Apple'S stock").  Only consulted for candidates that were
# NOT typed in uppercase, so an explicit "IT" or "ARE" is still accepted.
_TICKER_STOPWORDS = frozenset({
    "a", "an", "the", "this", "that", "these", "those",
    "is", "are", "was", "be", "s",
    "i", "me", "my", "we", "our", "us", "you", "your",
    "it", "its", "they", "their",
    "what", "which", "how", "who", "where", "when", "why",
    "get", "show", "tell", "give", "find",
    "and", "or", "of", "for", "to", "in", "on", "at", "by", "with", "about",
    "up", "now", "today", "some", "any", "all", "much",
    "stock", "share", "price", "quote",
})


def extract_ticker(q: str) -> Optional[str]:
    """Extract a stock ticker symbol from *q*.

    Strategies, in priority order:
      1. ``$SYMBOL`` notation (e.g. ``$AAPL``).
      2. Phrase patterns such as "price of AAPL", "AAPL stock", "what is AAPL".
         They match case-insensitively so "price of msft" still works, but a
         candidate typed in lowercase is rejected if it is a common word.
      3. Any standalone all-uppercase word of 2-5 characters.

    Returns:
        The ticker in uppercase, or None if nothing plausible was found.
    """
    # Check for $SYMBOL format first
    dollar_match = re.search(r"\$([A-Z]{1,5})\b", q, re.I)
    if dollar_match:
        return dollar_match.group(1).upper()
    # Check for common patterns like "price of AAPL" or "AAPL stock"
    patterns = [
        r"(?:price of|stock price of|quote for)\s+([A-Z]{1,5})\b",
        r"\b([A-Z]{1,5})\s+(?:stock|share|price|quote)",
        r"(?:what is|what's|get)\s+([A-Z]{1,5})\b",
    ]
    for pattern in patterns:
        # finditer (not search) so a rejected common word does not hide a
        # genuine ticker later in the same query.
        for match in re.finditer(pattern, q, re.I):
            candidate = match.group(1)
            if candidate.isupper() or candidate.lower() not in _TICKER_STOPWORDS:
                return candidate.upper()
    # Look for standalone uppercase tickers
    for word in q.split():
        clean_word = word.strip(".,!?")
        if 2 <= len(clean_word) <= 5 and clean_word.isupper():
            return clean_word
    return None
# Check if asking for market summary
def wants_market_summary(q: str) -> bool:
    """Return True if *q* asks for a broad market overview rather than one stock."""
    summary_regexes = (
        r"\bmarket\s+(?:summary|overview|today|status)\b",
        r"\bhow(?:'s| is) the market\b",
        r"\b(?:dow|nasdaq|s&p)\s+(?:today|now)\b",
        r"\bmarket indices\b",
    )
    first_hit = next(
        (rx for rx in summary_regexes if re.search(rx, q, re.I)),
        None,
    )
    return first_hit is not None
# Check if asking for news
def wants_news(q: str) -> bool:
    """Return True if *q* mentions news, a headline, an announcement or an update."""
    hit = re.search(r"\b(news|headline|announcement|update)\b", q, re.I)
    return hit is not None
def build_tool_map(tools):
    """Index *tools* by their ``name`` attribute (a later duplicate name wins)."""
    by_name = {}
    for tool in tools:
        by_name[tool.name] = tool
    return by_name
def find_tool(name: str, tool_map: dict):
    """Look up a tool by name, case-insensitively.

    A key matches when it equals *name* or ends with ``"/" + name`` (a
    prefixed ``something/tool`` key).  Returns the first match in the dict's
    iteration order, or None when nothing matches.
    """
    wanted = name.lower()
    suffix = "/" + wanted
    return next(
        (
            tool
            for key, tool in tool_map.items()
            if key.lower() == wanted or key.lower().endswith(suffix)
        ),
        None,
    )
async def build_chat_llm_with_fallback(candidates=None):
    """Return a ChatHuggingFace model for the first candidate that answers a probe.

    For each model id in *candidates* (default: ``CANDIDATE_MODELS``):
      - create HuggingFaceEndpoint + ChatHuggingFace
      - send a tiny 'ping' HumanMessage so an unusable model (e.g. one that
        answers 402/Payment Required) is detected here, not on a real request
    Any exception moves on to the next candidate.

    Args:
        candidates: Optional iterable of Hub repo ids to try, in order.

    Raises:
        RuntimeError: if no candidate could be initialized; the last
            underlying error is chained as ``__cause__``.
    """
    if candidates is None:
        candidates = CANDIDATE_MODELS
    last_err = None
    for mid in candidates:
        try:
            llm = HuggingFaceEndpoint(
                repo_id=mid,
                huggingfacehub_api_token=HF,
                temperature=0.1,
                max_new_tokens=256,  # Increased for better responses
            )
            model = ChatHuggingFace(llm=llm)
            # PROBE with a valid message type; the reply itself is discarded.
            await model.ainvoke([HumanMessage(content="ping")])
        except Exception as e:
            last_err = e
            msg = str(e)
            if "402" in msg or "Payment Required" in msg:
                print(f"[LLM] {mid} requires payment; trying next...")
            else:
                print(f"[LLM] {mid} error: {e}; trying next...")
            continue
        print(f"[LLM] Using: {mid}")
        return model
    # Chain the last failure so the real cause stays in the traceback.
    raise RuntimeError(
        f"Could not initialize any candidate model. Last error: {last_err}"
    ) from last_err
# Owns the MCP client, the discovered tools, the LLM and the ReAct agent, and
# answers chat messages (direct stock-tool routing first, agent otherwise).
class MCPAgent:
    """Lazily-initialized wrapper around the MCP client, tools, LLM and agent."""

    def __init__(self):
        self.client = None  # MultiServerMCPClient, set by initialize()
        self.agent = None  # ReAct agent built by create_react_agent
        self.tool_map = None  # tool name -> tool, from build_tool_map
        self.tools = None  # tools returned by client.get_tools()
        self.model = None  # chat model from build_chat_llm_with_fallback
        self.initialized = False

    async def initialize(self):
        """Start the MCP servers, discover tools, pick an LLM and build the agent.

        Once initialization has succeeded, later calls return immediately.

        Returns:
            The list of available tool names (on first and repeat calls alike).
        """
        if self.initialized:
            return list(self.tool_map.keys())
        # Both servers use stdio transport with a command line (this
        # interpreter + the server script), so the client launches them itself.
        self.client = MultiServerMCPClient({
            "arithmetic": {
                "command": sys.executable,
                "args": [str(here / "arithmetic_server.py")],
                "transport": "stdio",
            },
            "stocks": {
                "command": sys.executable,
                "args": [str(here / "stock_server.py")],
                "transport": "stdio",
            },
        })
        # 1. MCP client + tools
        self.tools = await self.client.get_tools()
        self.tool_map = build_tool_map(self.tools)
        # 2. Build LLM with auto-fallback
        self.model = await build_chat_llm_with_fallback()
        # 3. Build the ReAct agent with MCP tools
        self.agent = create_react_agent(self.model, self.tools)
        self.initialized = True
        return list(self.tool_map.keys())  # Return available tools

    async def process_message(self, user_text: str, history: List[Dict]) -> str:
        """Answer *user_text* given the prior chat *history*.

        Stock-looking queries are first sent straight to a stock tool (market
        summary, company news, or price) when the intent and, where needed, a
        ticker can be recognised and the tool exists; everything else goes
        through the ReAct agent.

        Args:
            user_text: The new user message.
            history: Prior turns as ``{"role": ..., "content": ...}`` dicts;
                None is treated as an empty history.

        Returns:
            The reply, coerced to text.
        """
        if not self.initialized:
            await self.initialize()
        # Try direct stock tool routing first
        if is_stock_query(user_text):
            if wants_market_summary(user_text):
                market_tool = find_tool("get_market_summary", self.tool_map)
                if market_tool:
                    return self._to_text(await market_tool.ainvoke({}))
            elif wants_news(user_text):
                ticker = extract_ticker(user_text)
                if ticker:
                    news_tool = find_tool("get_company_news", self.tool_map)
                    if news_tool:
                        return self._to_text(
                            await news_tool.ainvoke({"symbol": ticker, "limit": 3})
                        )
            else:
                ticker = extract_ticker(user_text)
                if ticker:
                    price_tool = find_tool("get_stock_price", self.tool_map)
                    if price_tool:
                        return self._to_text(
                            await price_tool.ainvoke({"symbol": ticker})
                        )
        # Fall back to agent for everything else
        messages = (
            [{"role": "system", "content": SYSTEM_PROMPT}]
            + list(history or [])
            + [{"role": "user", "content": user_text}]
        )
        result = await self.agent.ainvoke({"messages": messages})
        return self._to_text(result["messages"][-1].content)

    @staticmethod
    def _to_text(value) -> str:
        """Coerce a tool result or message content to a display string.

        Strings pass through unchanged.  Lists are joined with newlines,
        taking ``item["text"]`` from dict items that have one.  Anything else
        goes through ``str()``.
        NOTE(review): assumes list results are text parts/content blocks --
        confirm against the MCP adapter's return shapes.
        """
        if isinstance(value, str):
            return value
        if isinstance(value, list):
            parts = []
            for item in value:
                if isinstance(item, dict) and "text" in item:
                    parts.append(str(item["text"]))
                else:
                    parts.append(str(item))
            return "\n".join(parts)
        return str(value)

    async def cleanup(self):
        """Close the MCP client (if it exposes ``close``) and reset agent state.

        ``close`` may be sync or async; both are handled.  State is reset even
        if closing raises, so a later ``initialize()`` rebuilds everything
        instead of continuing to use the agent/tools created from the closed
        client.
        """
        client, self.client = self.client, None
        try:
            if client:
                close = getattr(client, "close", None)
                if callable(close):
                    res = close()
                    if asyncio.iscoroutine(res):
                        await res
        finally:
            self.agent = None
            self.tool_map = None
            self.tools = None
            self.model = None
            self.initialized = False
# Module-level singleton, created lazily by get_agent().
_agent_instance = None


def get_agent() -> MCPAgent:
    """Return the shared MCPAgent, constructing it on first use.

    Only construction happens here; the agent initializes itself lazily on
    its first ``process_message`` call (or via an explicit ``initialize()``).
    """
    global _agent_instance
    if _agent_instance is not None:
        return _agent_instance
    _agent_instance = MCPAgent()
    return _agent_instance