'''Classes to connect to the Anthropic inference endpoint.'''

import abc
from typing import Dict, List, Any, Optional
import anthropic
from client.mcp_client import MCPClientWrapper, ToolDef, ToolParameter, ToolInvocationResult

DEFAULT_ANTHROPIC_MODEL = 'claude-3-haiku-20240307'

# Type mapping from Python/MCP types to JSON Schema types
TYPE_MAPPING = {
    'int': 'integer',
    'bool': 'boolean',
    'str': 'string',
    'float': 'number',
    'list': 'array',
    'dict': 'object',
    'boolean': 'boolean',
    'string': 'string',
    'integer': 'integer',
    'number': 'number',
    'array': 'array',
    'object': 'object'
}


class LLMBridge(abc.ABC):
    '''Abstract base class for LLM bridge implementations.'''

    def __init__(self, mcp_client: MCPClientWrapper):
        '''Initialize the LLM bridge with an MCPClientWrapper instance.

        Args:
            mcp_client: An initialized MCPClientWrapper instance
        '''
        self.mcp_client = mcp_client
        self.tools = None


    async def fetch_tools(self) -> List[ToolDef]:
        '''Fetch available tools from the MCP endpoint.

        Returns:
            List of ToolDef objects
        '''
        self.tools = await self.mcp_client.list_tools()
        return self.tools


    @abc.abstractmethod
    async def format_tools(self, tools: List[ToolDef]) -> Any:
        '''Format tools for the specific LLM provider.

        Args:
            tools: List of ToolDef objects
            
        Returns:
            Formatted tools in the LLM-specific format
        '''
        pass


    @abc.abstractmethod
    async def submit_query(self, system_prompt: str, query: List[Dict], formatted_tools: Any) -> Dict[str, Any]:
        '''Submit a query to the LLM with the formatted tools.

        Args:
            system_prompt: System prompt for the LLM
            query: Conversation messages to send to the LLM
            formatted_tools: Tools in the LLM-specific format

        Returns:
            LLM response
        '''
        pass


    @abc.abstractmethod
    async def parse_tool_call(self, llm_response: Any) -> Optional[Dict[str, Any]]:
        '''Parse the LLM response to extract tool calls.

        Args:
            llm_response: Response from the LLM
            
        Returns:
            Dictionary with tool name and parameters, or None if no tool call
        '''
        pass


    async def execute_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
        '''Execute a tool with the given parameters.

        Args:
            tool_name: Name of the tool to invoke
            kwargs: Dictionary of parameters to pass to the tool
            
        Returns:
            ToolInvocationResult containing the tool's response
        '''
        return await self.mcp_client.invoke_tool(tool_name, kwargs)


    async def process_query(self, system_prompt: str, query: List[Dict]) -> Dict[str, Any]:
        '''Process a user query through the LLM and execute any tool calls.

        This method handles the full flow:
        1. Fetch tools if not already fetched
        2. Format tools for the LLM
        3. Submit query to LLM
        4. Parse tool calls from LLM response
        5. Execute tool if needed
        
        Args:
            system_prompt: System prompt for the LLM
            query: Conversation messages to send to the LLM

        Returns:
            Dictionary containing the LLM response, tool call, and tool result
        '''
        # 1. Fetch tools if not already fetched
        if self.tools is None:
            await self.fetch_tools()

        # 2. Format tools for the LLM
        formatted_tools = await self.format_tools(self.tools)

        # 3. Submit query to LLM
        llm_response = await self.submit_query(system_prompt, query, formatted_tools)

        # 4. Parse tool calls from LLM response
        tool_call = await self.parse_tool_call(llm_response)

        result = {
            'llm_response': llm_response,
            'tool_call': tool_call,
            'tool_result': None
        }

        # 5. Execute tool if needed
        if tool_call:
            tool_name = tool_call.get('name')
            kwargs = tool_call.get('parameters', {})
            tool_result = await self.execute_tool(tool_name, kwargs)
            result['tool_result'] = tool_result

        return result


class AnthropicBridge(LLMBridge):
    '''Anthropic-specific implementation of the LLM Bridge.'''

    def __init__(self, mcp_client, api_key, model=DEFAULT_ANTHROPIC_MODEL):
        '''Initialize the Anthropic bridge with an API key and model.

        Args:
            mcp_client: An initialized MCPClientWrapper instance
            api_key: Anthropic API key
            model: Anthropic model to use (default: DEFAULT_ANTHROPIC_MODEL)
        '''
        super().__init__(mcp_client)
        self.llm_client = anthropic.Anthropic(api_key=api_key)
        self.model = model


    async def format_tools(self, tools: List[ToolDef]) -> List[Dict[str, Any]]:
        '''Format tools for Anthropic.
        
        Args:
            tools: List of ToolDef objects
            
        Returns:
            List of tools in Anthropic format
        '''
        return to_anthropic_format(tools)


    async def submit_query(
            self,
            system_prompt: str,
            query: List[Dict],
            formatted_tools: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        '''Submit a query to Anthropic with the formatted tools.
        
        Args:
            system_prompt: System prompt for the model
            query: Conversation messages to send to Anthropic
            formatted_tools: Tools in Anthropic format
            
        Returns:
            Anthropic API response
        '''
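        # Note: anthropic.Anthropic is a synchronous client, so this call blocks
        # the event loop; anthropic.AsyncAnthropic could be swapped in for true async.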
        response = self.llm_client.messages.create(
            model=self.model,
            max_tokens=4096,
            system=system_prompt,
            messages=query,
            tools=formatted_tools
        )

        return response


    async def parse_tool_call(self, llm_response: Any) -> Optional[Dict[str, Any]]:
        '''Parse the Anthropic response to extract tool calls.
        
        Args:
            llm_response: Response from Anthropic
            
        Returns:
            Dictionary with tool name and parameters, or None if no tool call
        '''
        for content in llm_response.content:
            if content.type == 'tool_use':
                return {
                    'name': content.name,
                    'parameters': content.input
                }

        return None



def to_anthropic_format(tools: List[ToolDef]) -> List[Dict[str, Any]]:
    '''Convert ToolDef objects to Anthropic tool format.
    
    Args:
        tools: List of ToolDef objects to convert
        
    Returns:
        List of dictionaries in Anthropic tool format
    '''

    anthropic_tools = []
    for tool in tools:
        anthropic_tool = {
            'name': tool.name,
            'description': tool.description,
            'input_schema': {
                'type': 'object',
                'properties': {},
                'required': []
            }
        }

        # Add properties
        for param in tool.parameters:
            # Map the type or use the original if no mapping exists
            schema_type = TYPE_MAPPING.get(param.parameter_type, param.parameter_type)

            param_schema = {
                'type': schema_type,  # Use mapped type
                'description': param.description
            }

            # For arrays, we need to specify the items type
            if schema_type == 'array':
                item_type = _infer_array_item_type(param)
                param_schema['items'] = {'type': item_type}

            anthropic_tool['input_schema']['properties'][param.name] = param_schema

            # Add default value if provided
            if param.default is not None:
                anthropic_tool['input_schema']['properties'][param.name]['default'] = param.default

            # Add to required list if required
            if param.required:
                anthropic_tool['input_schema']['required'].append(param.name)

        anthropic_tools.append(anthropic_tool)
    return anthropic_tools
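

# A minimal, hedged sketch of what to_anthropic_format produces for one tool.
# The ToolDef/ToolParameter keyword arguments below are assumptions inferred from
# the attributes this module reads (name, description, parameter_type, required,
# default); adjust them to the real constructors in client.mcp_client.
def _example_to_anthropic_format() -> List[Dict[str, Any]]:
    '''Illustrative only: convert one hypothetical ToolDef and return the result.'''
    sample_tool = ToolDef(
        name='get_weather',
        description='Fetch the current weather for a city',
        parameters=[
            ToolParameter(
                name='city',
                description='Name of the city to look up',
                parameter_type='str',
                required=True,
                default=None,
            ),
        ],
    )
    # Expected shape, roughly:
    # {'name': 'get_weather',
    #  'description': 'Fetch the current weather for a city',
    #  'input_schema': {'type': 'object',
    #                   'properties': {'city': {'type': 'string',
    #                                           'description': 'Name of the city to look up'}},
    #                   'required': ['city']}}
    return to_anthropic_format([sample_tool])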


def _infer_array_item_type(param: ToolParameter) -> str:
    '''Infer the item type for an array parameter based on its name and description.
    
    Args:
        param: The ToolParameter object
        
    Returns:
        The inferred JSON Schema type for array items
    '''
    # Default to string items
    item_type = 'string'

    # Check if parameter name contains hints about item type
    param_name_lower = param.name.lower()
    if any(hint in param_name_lower for hint in ['language', 'code', 'tag', 'name', 'id']):
        item_type = 'string'
    elif any(hint in param_name_lower for hint in ['number', 'count', 'amount', 'index']):
        item_type = 'integer'

    # Also check the description for hints
    if param.description:
        desc_lower = param.description.lower()
        if 'string' in desc_lower or 'text' in desc_lower or 'language' in desc_lower:
            item_type = 'string'
        elif 'number' in desc_lower or 'integer' in desc_lower or 'int' in desc_lower:
            item_type = 'integer'

    return item_type
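

# Hedged usage sketch: wiring the bridge end to end. The MCPClientWrapper
# constructor argument below is an assumption (it may instead take a server
# command or other connection details); the Anthropic API key is read from the
# environment rather than hard-coded.
if __name__ == '__main__':
    import asyncio
    import os

    async def _demo() -> None:
        mcp_client = MCPClientWrapper('http://localhost:8000/sse')  # assumed constructor
        bridge = AnthropicBridge(mcp_client, api_key=os.environ['ANTHROPIC_API_KEY'])

        result = await bridge.process_query(
            system_prompt='You are a helpful assistant with access to tools.',
            query=[{'role': 'user', 'content': 'What is the weather in Paris?'}],
        )
        print(result['llm_response'])
        if result['tool_call'] is not None:
            print('Tool called:', result['tool_call']['name'])
            print('Tool result:', result['tool_result'])

    asyncio.run(_demo())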