# backend.py
"""
Backend module for MCP Agent
Handles all the MCP server connections, LLM setup, and agent logic
"""
import sys
import os
import re
import asyncio
from dotenv import load_dotenv
from typing import Optional, Dict, List
from pathlib import Path

# Load variables from a local .env file (e.g. HF_TOKEN) before reading the environment.
load_dotenv()

# Directory containing this file; used to locate the MCP server scripts.
here = Path(__file__).parent.resolve()


################### --- Auth setup --- ###################
##########################################################
HF = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF:
    print("WARNING: HF_TOKEN not set. The app will start, but model calls may fail.")
else:
    os.environ["HF_TOKEN"] = HF
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF
    try:
        from huggingface_hub import login
        login(token=HF)
    except Exception:
        pass
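# Illustrative .env entry consumed by load_dotenv() above (the value is a placeholder):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx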


# --- LangChain / MCP ---
from langgraph.prebuilt import create_react_agent
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_core.messages import HumanMessage

# First choice, then free/tiny fallbacks
CANDIDATE_MODELS = [
    os.getenv("HF_MODEL_ID", "Qwen/Qwen2.5-7B-Instruct"),
    "HuggingFaceTB/SmolLM3-3B-Instruct",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
]
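# The first-choice model can be swapped without code changes, e.g. (illustrative value):
#   export HF_MODEL_ID="Qwen/Qwen2.5-72B-Instruct"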

SYSTEM_PROMPT = (
    "You are an AI assistant with tools.\n"
    "- Use the arithmetic tools (`add`, `minus`, `multiply`, `divide`) for arithmetic or multi-step calculations.\n"
    "- Use the stock tools (`get_stock_price`, `get_market_summary`, `get_company_news`) for financial/market queries.\n"
    "- Otherwise, answer directly with your own knowledge.\n"
    "Be concise and accurate. Only call tools when they clearly help."
)

################### --- ROUTER helpers --- ###################
##############################################################

# Detect stock/financial intent
def is_stock_query(q: str) -> bool:
    """Check if query is about stocks, markets, or financial data."""
    stock_patterns = [
        r"\b(stock|share|price|ticker|market|nasdaq|dow|s&p|spy|qqq)\b",
        r"\b(AAPL|GOOGL|MSFT|TSLA|AMZN|META|NVDA|AMD)\b",  # Common tickers
        r"\$[A-Z]{1,5}\b",  # $SYMBOL format
        r"\b(trading|invest|portfolio|earnings|dividend)\b",
        r"\b(bull|bear|rally|crash|volatility)\b",
    ]
    return any(re.search(pattern, q, re.I) for pattern in stock_patterns)

# Extract ticker symbol from query
def extract_ticker(q: str) -> Optional[str]:
    """Extract stock ticker from query."""
    # Check for $SYMBOL format first
    dollar_match = re.search(r"\$([A-Z]{1,5})\b", q, re.I)
    if dollar_match:
        return dollar_match.group(1).upper()
    
    # Check for common patterns like "price of AAPL" or "AAPL stock"
    patterns = [
        r"(?:price of|stock price of|quote for)\s+([A-Z]{1,5})\b",
        r"\b([A-Z]{1,5})\s+(?:stock|share|price|quote)",
        r"(?:what is|what's|get)\s+([A-Z]{1,5})\b",
    ]
    
    for pattern in patterns:
        match = re.search(pattern, q, re.I)
        if match:
            return match.group(1).upper()
    
    # Look for standalone uppercase tickers
    words = q.split()
    for word in words:
        clean_word = word.strip(".,!?")
        if 2 <= len(clean_word) <= 5 and clean_word.isupper():
            return clean_word
    
    return None

# Check if asking for market summary
def wants_market_summary(q: str) -> bool:
    """Check if user wants overall market summary."""
    patterns = [
        r"\bmarket\s+(?:summary|overview|today|status)\b",
        r"\bhow(?:'s| is) the market\b",
        r"\b(?:dow|nasdaq|s&p)\s+(?:today|now)\b",
        r"\bmarket indices\b",
    ]
    return any(re.search(pattern, q, re.I) for pattern in patterns)

# Check if asking for news
def wants_news(q: str) -> bool:
    """Check if user wants company news."""
    return bool(re.search(r"\b(news|headline|announcement|update)\b", q, re.I))
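# Illustrative behaviour of the routing helpers above (derived from the regexes):
#   is_stock_query("what's the price of $AAPL?")    -> True
#   extract_ticker("what's the price of $AAPL?")    -> "AAPL"
#   wants_market_summary("how's the market today?") -> True
#   wants_news("any news on TSLA?")                 -> True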

def build_tool_map(tools):
    """Map tool name -> tool object for quick lookup."""
    return {t.name: t for t in tools}

def find_tool(name: str, tool_map: dict):
    name = name.lower()
    for k, t in tool_map.items():
        kl = k.lower()
        if kl == name or kl.endswith("/" + name):
            return t
    return None

async def build_chat_llm_with_fallback():
    """
    Try each candidate model. For each:
        - create HuggingFaceEndpoint + ChatHuggingFace
        - do a tiny 'ping' with a proper LC message to trigger routing
    On 402/Payment Required (or other errors), fall through to next.
    """
    last_err = None
    for mid in CANDIDATE_MODELS:
        try:
            llm = HuggingFaceEndpoint(
                repo_id=mid,
                huggingfacehub_api_token=HF,
                temperature=0.1,
                max_new_tokens=256,  # Increased for better responses
            )
            model = ChatHuggingFace(llm=llm)
            # PROBE with a valid message type
            _ = await model.ainvoke([HumanMessage(content="ping")])
            print(f"[LLM] Using: {mid}")
            return model
        except Exception as e:
            msg = str(e)
            if "402" in msg or "Payment Required" in msg:
                print(f"[LLM] {mid} requires payment; trying next...")
                last_err = e
                continue
            print(f"[LLM] {mid} error: {e}; trying next...")
            last_err = e
            continue
    raise RuntimeError(f"Could not initialize any candidate model. Last error: {last_err}")

# NEW: Class to manage the MCP Agent (moved from main function)
class MCPAgent:
    def __init__(self):
        self.client = None
        self.agent = None
        self.tool_map = None
        self.tools = None
        self.model = None
        self.initialized = False
        
    async def initialize(self):
        """Initialize the MCP client and agent"""
        if self.initialized:
            return list(self.tool_map.keys())
            
        # Both MCP servers are launched as stdio subprocesses by the client;
        # arithmetic_server.py and stock_server.py must sit next to this file.
        self.client = MultiServerMCPClient({
            "arithmetic": {
                "command": sys.executable,
                "args": [str(here / "arithmetic_server.py")],
                "transport": "stdio",
            },
            "stocks": {
                "command": sys.executable,
                "args": [str(here / "stock_server.py")],
                "transport": "stdio",
            },
        })
        
        # 1. MCP client + tools
        self.tools = await self.client.get_tools()
        self.tool_map = build_tool_map(self.tools)
        
        # 2. Build LLM with auto-fallback
        self.model = await build_chat_llm_with_fallback()
        
        # Build the ReAct agent with MCP tools
        self.agent = create_react_agent(self.model, self.tools)
        
        self.initialized = True
        return list(self.tool_map.keys())  # Return available tools
    
    async def process_message(self, user_text: str, history: List[Dict]) -> str:
        """Process a single message with the agent"""
        if not self.initialized:
            await self.initialize()
        
        # Try direct stock tool routing first
        if is_stock_query(user_text):
            if wants_market_summary(user_text):
                market_tool = find_tool("get_market_summary", self.tool_map)
                if market_tool:
                    return await market_tool.ainvoke({})
            
            elif wants_news(user_text):
                ticker = extract_ticker(user_text)
                if ticker:
                    news_tool = find_tool("get_company_news", self.tool_map)
                    if news_tool:
                        return await news_tool.ainvoke({"symbol": ticker, "limit": 3})
            
            else:
                ticker = extract_ticker(user_text)
                if ticker:
                    price_tool = find_tool("get_stock_price", self.tool_map)
                    if price_tool:
                        return await price_tool.ainvoke({"symbol": ticker})
        
        # Fall back to agent for everything else
        messages = [{"role": "system", "content": SYSTEM_PROMPT}] + history + [
            {"role": "user", "content": user_text}
        ]
        result = await self.agent.ainvoke({"messages": messages})
        return result["messages"][-1].content
    
    async def cleanup(self):
        """Clean up resources"""
        if self.client:
            close = getattr(self.client, "close", None)
            if callable(close):
                res = close()
                if asyncio.iscoroutine(res):
                    await res


# NEW: Singleton instance for the agent
_agent_instance = None

def get_agent() -> MCPAgent:
    """Get or create the singleton agent instance"""
    global _agent_instance
    if _agent_instance is None:
        _agent_instance = MCPAgent()
    return _agent_instance
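

# --- Illustrative usage sketch ---
# A minimal example of how a frontend (e.g. an app.py entry point, assumed here)
# might drive this module; only get_agent(), initialize(), process_message(),
# and cleanup() come from the code above, the demo query is made up.
async def _demo() -> None:
    agent = get_agent()
    tools = await agent.initialize()  # spawns the MCP servers and picks a working LLM
    print("Available tools:", tools)
    reply = await agent.process_message("What is the price of $AAPL?", history=[])
    print(reply)
    await agent.cleanup()


if __name__ == "__main__":
    asyncio.run(_demo())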