davidgturner committed
Commit 07ad0d5 · 1 Parent(s): f5bafc2

- changes for local model

Files changed (5):
  1. app.py +72 -26
  2. config.py +12 -0
  3. requirements.txt +4 -1
  4. test_agent.py +5 -4
  5. utils/local_model.py +176 -0
app.py CHANGED
@@ -1,44 +1,41 @@
 import os
 import gradio as gr
 import requests
-import inspect
 import pandas as pd
-import time
-import json
-import io
-import base64
-from typing import Dict, List, Union, Optional
-import re
-import sys
-from bs4 import BeautifulSoup
-from duckduckgo_search import DDGS
-import pytube
-from dateutil import parser
-try:
-    from youtube_transcript_api import YouTubeTranscriptApi
-except ImportError:
-    print("YouTube Transcript API not installed. Video transcription may be limited.")
-
-from smolagents import Tool, CodeAgent, InferenceClientModel
+from smolagents import Tool, CodeAgent, Model
 
 # Import internal modules
 from config import (
-    DEFAULT_API_URL, LLAMA_API_URL, HF_API_TOKEN, HEADERS,
-    MAX_RETRIES, RETRY_DELAY
+    DEFAULT_API_URL
 )
 from tools.tool_manager import ToolManager
+from utils.local_model import LocalTransformersModel
 
 class GaiaToolCallingAgent:
     """Tool-calling agent specifically designed for the GAIA system."""
 
-    def __init__(self):
+    def __init__(self, local_model=None):
         print("GaiaToolCallingAgent initialized.")
         self.tool_manager = ToolManager()
         self.name = "tool_agent"  # Add required name attribute for smolagents integration
         self.description = "A specialized agent that uses various tools to answer questions"  # Required by smolagents
 
+        # Use local model if provided, or create a simpler one
+        self.local_model = local_model
+        if not self.local_model:
+            try:
+                from utils.local_model import LocalTransformersModel
+                self.local_model = LocalTransformersModel(
+                    model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+                    max_tokens=512
+                )
+            except Exception as e:
+                print(f"Couldn't initialize local model in tool agent: {e}")
+                self.local_model = None
+
     def run(self, query: str) -> str:
         """Process a query and return a response using available tools."""
+        print(f"Processing query: {query}")
         tools = self.tool_manager.get_tools()
 
         # For each tool, try to get relevant information
@@ -47,6 +44,7 @@ class GaiaToolCallingAgent:
         for tool in tools:
             try:
                 if self._should_use_tool(tool, query):
+                    print(f"Using tool: {tool.name}")
                     result = tool.forward(query)
                     if result:
                         context_info.append(f"{tool.name} Results:\n{result}")
@@ -56,7 +54,29 @@
         # Combine all context information
         full_context = "\n\n".join(context_info) if context_info else ""
 
-        return full_context
+        # If we have context and a local model, generate a proper response
+        if full_context and self.local_model:
+            try:
+                prompt = f"""
+Based on the following information, please provide a comprehensive answer to the question: "{query}"
+
+CONTEXT INFORMATION:
+{full_context}
+
+Answer:
+"""
+
+                response = self.local_model.generate(prompt)
+                return response
+            except Exception as e:
+                print(f"Error generating response with local model: {e}")
+                # Fall back to returning just the context
+                return full_context
+        else:
+            # No context or no model, return whatever we have
+            if not full_context:
+                return "I couldn't find any relevant information to answer your question."
+            return full_context
 
     def __call__(self, query: str) -> str:
         """Make the agent callable so it can be used directly by CodeAgent."""
@@ -76,22 +96,47 @@ class GaiaToolCallingAgent:
             "gaia_retriever": ["gaia", "agent", "ai", "artificial intelligence"]
         }
 
+        # Use all tools if patterns dict doesn't have the tool name
+        if tool.name not in patterns:
+            return True
+
         return any(pattern in query_lower for pattern in patterns.get(tool.name, []))
 
 def create_manager_agent() -> CodeAgent:
     """Create and configure the main GAIA agent."""
 
-    # Initialize the managed tool-calling agent
-    tool_agent = GaiaToolCallingAgent()
+    try:
+        # Import config for local model
+        from config import LOCAL_MODEL_CONFIG
+
+        # Use local model to avoid credit limits
+        model = LocalTransformersModel(
+            model_name=LOCAL_MODEL_CONFIG["model_name"],
+            device=LOCAL_MODEL_CONFIG["device"],
+            max_tokens=LOCAL_MODEL_CONFIG["max_tokens"],
+            temperature=LOCAL_MODEL_CONFIG["temperature"]
+        )
+        print(f"Using local model: {LOCAL_MODEL_CONFIG['model_name']}")
+    except Exception as e:
+        print(f"Error setting up local model: {e}")
+        # Use a simplified configuration as fallback
+        model = LocalTransformersModel(
+            model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+            device="cpu"
+        )
+        print("Using fallback model configuration")
+
+    # Initialize the managed tool-calling agent, sharing the model
+    tool_agent = GaiaToolCallingAgent(local_model=model)
 
     # Create the manager agent
     manager_agent = CodeAgent(
-        model=InferenceClientModel(),
+        model=model,
        tools=[],  # No direct tools for manager
        managed_agents=[tool_agent],
        additional_authorized_imports=[
            "json",
-           "pandas",
+           "pandas",
            "numpy",
            "re",
            "requests",
@@ -102,6 +147,7 @@ def create_manager_agent() -> CodeAgent:
         max_steps=10
     )
 
+    print("Manager agent created with local model")
     return manager_agent
 
 def create_agent():
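
A minimal usage sketch for the rewired entry point (illustrative, not part of the commit; it assumes the Space's dependencies are installed and that the TinyLlama weights can be downloaded on first run):

    from app import create_manager_agent

    # Builds LocalTransformersModel, shares it with the tool agent,
    # and wraps both in the CodeAgent manager.
    agent = create_manager_agent()
    print(agent.run("Which tools does the GAIA system expose?"))
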
config.py CHANGED
@@ -7,6 +7,15 @@ LLAMA_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8
 HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 HEADERS = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
 
+# --- Model Configuration ---
+USE_LOCAL_MODEL = True  # Set to False to use remote API model instead
+LOCAL_MODEL_CONFIG = {
+    "model_name": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",  # A small but capable model
+    "device": "auto",  # Will use GPU if available
+    "max_tokens": 1024,
+    "temperature": 0.5
+}
+
 # --- Request Configuration ---
 MAX_RETRIES = 3
 RETRY_DELAY = 2  # seconds
@@ -65,3 +74,6 @@ ANSWER_PREFIXES_TO_REMOVE = [
 
 LLM_RESPONSE_MARKERS = ["<answer>", "<response>", "Answer:", "Response:", "Assistant:"]
 LLM_END_MARKERS = ["</answer>", "</response>", "Human:", "User:"]
+
+# Ensure knowledge base is loaded correctly
+GAIA_KNOWLEDGE = load_knowledge_base()
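
For reference, a sketch of how a caller might honor the new USE_LOCAL_MODEL flag (hedged: app.py currently constructs LocalTransformersModel unconditionally; the remote branch below reuses smolagents' InferenceClientModel, the model this commit removes from app.py):

    from config import USE_LOCAL_MODEL, LOCAL_MODEL_CONFIG

    if USE_LOCAL_MODEL:
        from utils.local_model import LocalTransformersModel
        # LOCAL_MODEL_CONFIG keys mirror LocalTransformersModel.__init__
        model = LocalTransformersModel(**LOCAL_MODEL_CONFIG)
    else:
        from smolagents import InferenceClientModel
        model = InferenceClientModel()  # remote API model; requires HF_API_TOKEN
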
requirements.txt CHANGED
@@ -1,3 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
 gradio
 requests
 pandas
@@ -10,4 +11,6 @@ duckduckgo-search
 rank_bm25
 pytube
 python-dateutil
-youtube-transcript-api
+youtube-transcript-api
+torch
+transformers
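
One way to verify that the CPU-only torch build resolved from the extra index (a quick check, assuming the install succeeded):

    import torch

    print(torch.__version__)          # a correct install shows a "+cpu" suffix
    print(torch.cuda.is_available())  # False on the CPU-only wheel
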
test_agent.py CHANGED
@@ -1,8 +1,9 @@
 import os
-from app import GaiaAgent
+from app import create_agent
 
 # Initialize the agent
-agent = GaiaAgent()
+print("Creating agent for testing...")
+agent = create_agent()
 
 # Test cases from the logs that were failing
 test_questions = [
@@ -17,7 +18,7 @@ test_questions = [
 for question in test_questions:
     print(f"\nTesting question: {question}")
     try:
-        answer = agent(question)
-        print(f"Agent answer: {answer}")
+        response = agent.run(f"Answer this question concisely: {question}")
+        print(f"Agent answer: {response}")
     except Exception as e:
         print(f"Error: {e}")
utils/local_model.py ADDED
@@ -0,0 +1,176 @@
+"""
+Custom model implementation using Hugging Face Transformers.
+
+This provides a local model implementation compatible with the smolagents framework.
+"""
+
+import logging
+from typing import Dict, List, Optional, Any
+from smolagents.models import Model
+from transformers import AutoTokenizer, pipeline
+
+logger = logging.getLogger(__name__)
+
+class LocalTransformersModel(Model):
+    """Model using local Hugging Face Transformers models that doesn't require API calls."""
+
+    def __init__(
+        self,
+        model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        device: str = "auto",
+        max_tokens: int = 512,
+        temperature: float = 0.7
+    ):
+        """
+        Initialize a local transformer model.
+
+        Args:
+            model_name: Hugging Face model identifier
+            device: "cpu", "cuda", or "auto"
+            max_tokens: Maximum new tokens to generate
+            temperature: Sampling temperature
+        """
+        super().__init__()
+
+        try:
+            print(f"Loading model {model_name}...")
+
+            self.model_name = model_name
+            self.device = device
+            self.max_tokens = max_tokens
+            self.temperature = temperature
+
+            # Determine if we can use a GPU
+            if device == "auto":
+                import torch
+                self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+            # Load tokenizer and pipeline
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+            # Create text generation pipeline
+            self.generator = pipeline(
+                "text-generation",
+                model=model_name,
+                tokenizer=self.tokenizer,
+                device=self.device,
+                torch_dtype="auto"
+            )
+
+            print(f"Model loaded on {self.device}")
+
+        except Exception as e:
+            logger.error(f"Error loading model {model_name}: {e}")
+            print(f"Error loading model: {e}")
+            raise
+
+    def generate(self, prompt: str, **kwargs) -> str:
+        """
+        Generate a text completion for the given prompt.
+
+        Args:
+            prompt: Input text
+
+        Returns:
+            Generated text completion
+        """
+        try:
+            print(f"Generating with prompt: {prompt[:50]}...")
+
+            # Actual generation
+            response = self.generator(
+                prompt,
+                max_new_tokens=self.max_tokens,
+                temperature=self.temperature,
+                do_sample=True,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+
+            # Extract the generated text
+            generated_text = response[0]['generated_text']
+
+            # Remove the prompt from the beginning
+            if generated_text.startswith(prompt):
+                generated_text = generated_text[len(prompt):]
+
+            return generated_text.strip()
+
+        except Exception as e:
+            logger.error(f"Error generating text: {e}")
+            print(f"Error generating text: {e}")
+            return f"Error: {str(e)}"
+
+    def generate_with_tools(
+        self,
+        messages: List[Dict[str, Any]],
+        tools: Optional[List[Dict[str, Any]]] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Generate a response with tool-calling capabilities.
+        This method implements the smolagents Model interface for tool calling.
+
+        Args:
+            messages: List of message objects with role and content
+            tools: List of tool definitions
+
+        Returns:
+            Response with a message and optional tool calls
+        """
+        try:
+            # Format messages into a prompt
+            prompt = self._format_messages_to_prompt(messages, tools)
+
+            # Generate response
+            completion = self.generate(prompt)
+
+            # For now, just return the text without tool parsing;
+            # a future enhancement could add tool-call parsing here
+            return {
+                "message": {
+                    "role": "assistant",
+                    "content": completion
+                }
+            }
+        except Exception as e:
+            logger.error(f"Error generating with tools: {e}")
+            print(f"Error generating with tools: {e}")
+            return {
+                "message": {
+                    "role": "assistant",
+                    "content": f"Error: {str(e)}"
+                }
+            }
+
+    def _format_messages_to_prompt(
+        self,
+        messages: List[Dict[str, Any]],
+        tools: Optional[List[Dict[str, Any]]] = None
+    ) -> str:
+        """Format chat messages into a text prompt for the model."""
+        formatted_prompt = ""
+
+        # Include tool descriptions if available
+        if tools and len(tools) > 0:
+            tool_descriptions = "\n".join([
+                f"Tool {i+1}: {tool['name']} - {tool['description']}"
+                for i, tool in enumerate(tools)
+            ])
+            formatted_prompt += f"Available tools:\n{tool_descriptions}\n\n"
+
+        # Add conversation history
+        for msg in messages:
+            role = msg.get("role", "")
+            content = msg.get("content", "")
+
+            if role == "system":
+                formatted_prompt += f"System: {content}\n\n"
+            elif role == "user":
+                formatted_prompt += f"User: {content}\n\n"
+            elif role == "assistant":
+                formatted_prompt += f"Assistant: {content}\n\n"
+
+        # Add final prompt for the assistant
+        formatted_prompt += "Assistant: "
+
+        return formatted_prompt
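
A short usage sketch for the new wrapper (illustrative; the first call downloads the TinyLlama weights, and CPU generation is slow but needs no API credits):

    from utils.local_model import LocalTransformersModel

    model = LocalTransformersModel(
        model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        device="cpu",
        max_tokens=128,
    )

    # Plain completion path
    print(model.generate("Q: What is the GAIA benchmark?\nA:"))

    # Chat-style path, as smolagents would invoke it
    reply = model.generate_with_tools(
        messages=[{"role": "user", "content": "Say hello in one sentence."}]
    )
    print(reply["message"]["content"])
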