Commit 07ad0d5
Parent(s): f5bafc2

changes for local model

Files changed:
- app.py (+72, -26)
- config.py (+12, -0)
- requirements.txt (+7, -1)
- test_agent.py (+5, -4)
- utils/local_model.py (+177, -0)
app.py
CHANGED

@@ -1,44 +1,41 @@
 import os
 import gradio as gr
 import requests
-import inspect
 import pandas as pd
-import
-import json
-import io
-import base64
-from typing import Dict, List, Union, Optional
-import re
-import sys
-from bs4 import BeautifulSoup
-from duckduckgo_search import DDGS
-import pytube
-from dateutil import parser
-try:
-    from youtube_transcript_api import YouTubeTranscriptApi
-except ImportError:
-    print("YouTube Transcript API not installed. Video transcription may be limited.")
-
-from smolagents import Tool, CodeAgent, InferenceClientModel
+from smolagents import Tool, CodeAgent, Model
 
 # Import internal modules
 from config import (
-    DEFAULT_API_URL
-    MAX_RETRIES, RETRY_DELAY
+    DEFAULT_API_URL
 )
 from tools.tool_manager import ToolManager
+from utils.local_model import LocalTransformersModel
 
 class GaiaToolCallingAgent:
     """Tool-calling agent specifically designed for the GAIA system."""
 
-    def __init__(self):
+    def __init__(self, local_model=None):
         print("GaiaToolCallingAgent initialized.")
         self.tool_manager = ToolManager()
         self.name = "tool_agent"  # Add required name attribute for smolagents integration
         self.description = "A specialized agent that uses various tools to answer questions"  # Required by smolagents
 
+        # Use local model if provided, or create a simpler one
+        self.local_model = local_model
+        if not self.local_model:
+            try:
+                from utils.local_model import LocalTransformersModel
+                self.local_model = LocalTransformersModel(
+                    model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+                    max_tokens=512
+                )
+            except Exception as e:
+                print(f"Couldn't initialize local model in tool agent: {e}")
+                self.local_model = None
+
     def run(self, query: str) -> str:
         """Process a query and return a response using available tools."""
+        print(f"Processing query: {query}")
         tools = self.tool_manager.get_tools()
 
         # For each tool, try to get relevant information

@@ -47,6 +44,7 @@ class GaiaToolCallingAgent:
         for tool in tools:
             try:
                 if self._should_use_tool(tool, query):
+                    print(f"Using tool: {tool.name}")
                     result = tool.forward(query)
                     if result:
                         context_info.append(f"{tool.name} Results:\n{result}")

@@ -56,7 +54,29 @@
         # Combine all context information
         full_context = "\n\n".join(context_info) if context_info else ""
 
-
+        # If we have context and a local model, generate a proper response
+        if full_context and self.local_model:
+            try:
+                prompt = f"""
+                Based on the following information, please provide a comprehensive answer to the question: "{query}"
+
+                CONTEXT INFORMATION:
+                {full_context}
+
+                Answer:
+                """
+
+                response = self.local_model.generate(prompt)
+                return response
+            except Exception as e:
+                print(f"Error generating response with local model: {e}")
+                # Fall back to returning just the context
+                return full_context
+        else:
+            # No context or no model, return whatever we have
+            if not full_context:
+                return "I couldn't find any relevant information to answer your question."
+            return full_context
 
     def __call__(self, query: str) -> str:
         """Make the agent callable so it can be used directly by CodeAgent."""

@@ -76,22 +96,47 @@ class GaiaToolCallingAgent:
             "gaia_retriever": ["gaia", "agent", "ai", "artificial intelligence"]
         }
 
+        # Use all tools if patterns dict doesn't have the tool name
+        if tool.name not in patterns:
+            return True
+
         return any(pattern in query_lower for pattern in patterns.get(tool.name, []))
 
 def create_manager_agent() -> CodeAgent:
     """Create and configure the main GAIA agent."""
 
-
-
+    try:
+        # Import config for local model
+        from config import LOCAL_MODEL_CONFIG
+
+        # Use local model to avoid credit limits
+        model = LocalTransformersModel(
+            model_name=LOCAL_MODEL_CONFIG["model_name"],
+            device=LOCAL_MODEL_CONFIG["device"],
+            max_tokens=LOCAL_MODEL_CONFIG["max_tokens"],
+            temperature=LOCAL_MODEL_CONFIG["temperature"]
+        )
+        print(f"Using local model: {LOCAL_MODEL_CONFIG['model_name']}")
+    except Exception as e:
+        print(f"Error setting up local model: {e}")
+        # Use a simplified configuration as fallback
+        model = LocalTransformersModel(
+            model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+            device="cpu"
+        )
+        print("Using fallback model configuration")
+
+    # Initialize the managed tool-calling agent, sharing the model
+    tool_agent = GaiaToolCallingAgent(local_model=model)
 
     # Create the manager agent
     manager_agent = CodeAgent(
-        model=
+        model=model,
         tools=[],  # No direct tools for manager
        managed_agents=[tool_agent],
         additional_authorized_imports=[
             "json",
-            "pandas",
+            "pandas",
             "numpy",
             "re",
             "requests",

@@ -102,6 +147,7 @@ def create_manager_agent() -> CodeAgent:
         max_steps=10
     )
 
+    print("Manager agent created with local model")
    return manager_agent
 
 def create_agent():
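The rewired create_manager_agent() now builds a LocalTransformersModel, hands the same model to both the managed tool agent and the CodeAgent, and no longer goes through the Inference API. A minimal smoke test of that path might look like the following; this is a sketch rather than code from the commit, and it assumes the functions above are importable from app and that transformers can download the TinyLlama weights on first run:

# Sketch only: exercise the local-model wiring added to app.py.
from app import create_manager_agent

manager = create_manager_agent()   # loads the local model and attaches the managed tool_agent
answer = manager.run("What is the GAIA benchmark?")   # CodeAgent.run() drives the managed agent
print(answer)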
config.py
CHANGED

@@ -7,6 +7,15 @@ LLAMA_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8
 HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 HEADERS = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
 
+# --- Model Configuration ---
+USE_LOCAL_MODEL = True  # Set to False to use remote API model instead
+LOCAL_MODEL_CONFIG = {
+    "model_name": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",  # A small but capable model
+    "device": "auto",  # Will use GPU if available
+    "max_tokens": 1024,
+    "temperature": 0.5
+}
+
 # --- Request Configuration ---
 MAX_RETRIES = 3
 RETRY_DELAY = 2  # seconds

@@ -65,3 +74,6 @@ ANSWER_PREFIXES_TO_REMOVE = [
 
 LLM_RESPONSE_MARKERS = ["<answer>", "<response>", "Answer:", "Response:", "Assistant:"]
 LLM_END_MARKERS = ["</answer>", "</response>", "Human:", "User:"]
+
+# Ensure knowledge base is loaded correctly
+GAIA_KNOWLEDGE = load_knowledge_base()
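USE_LOCAL_MODEL is introduced here as a switch, but the hunks shown in this commit only read LOCAL_MODEL_CONFIG. A hedged sketch of how the flag could gate model selection is below; the remote branch mirrors the InferenceClientModel import that app.py dropped, and none of this is code from the repository:

# Illustration only: gate the model choice on the new config flag.
from config import USE_LOCAL_MODEL, LOCAL_MODEL_CONFIG

if USE_LOCAL_MODEL:
    from utils.local_model import LocalTransformersModel
    model = LocalTransformersModel(**LOCAL_MODEL_CONFIG)  # keys match the constructor arguments
else:
    from smolagents import InferenceClientModel
    model = InferenceClientModel()  # remote API model, as used before this commit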
requirements.txt
CHANGED

@@ -1,3 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
 gradio
 requests
 pandas

@@ -10,4 +11,9 @@ duckduckgo-search
 rank_bm25
 pytube
 python-dateutil
-youtube-transcript-api
+youtube-transcript-api
+torch
+transformers
+torch==2.1.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+torchvision==0.14.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+torchaudio==0.10.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
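Note that the added lines request torch twice (once unpinned, once as torch==2.1.1+cpu), append a per-line -f option that pip's requirements format normally rejects, and pin torchvision/torchaudio versions that belong to older torch releases. A consistent CPU-only set could look like the sketch below; the exact version pairings are my assumption, not part of the commit:

--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.1.1
torchvision==0.16.1
torchaudio==2.1.1
transformers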
test_agent.py
CHANGED

@@ -1,8 +1,9 @@
 import os
-from app import
+from app import create_agent
 
 # Initialize the agent
-agent
+print("Creating agent for testing...")
+agent = create_agent()
 
 # Test cases from the logs that were failing
 test_questions = [

@@ -17,7 +18,7 @@ test_questions = [
 for question in test_questions:
     print(f"\nTesting question: {question}")
     try:
-
-        print(f"Agent answer: {
+        response = agent.run(f"Answer this question concisely: {question}")
+        print(f"Agent answer: {response}")
     except Exception as e:
         print(f"Error: {e}")
utils/local_model.py
ADDED

@@ -0,0 +1,177 @@
"""
Custom model implementation using Hugging Face Transformers.

This provides a local model implementation compatible with smolagents framework.
"""

import logging
from typing import Dict, List, Optional, Any
from smolagents.models import Model
from transformers import AutoTokenizer, pipeline

logger = logging.getLogger(__name__)

class LocalTransformersModel(Model):
    """Model using local Hugging Face Transformers models that doesn't require API calls."""

    def __init__(
        self,
        model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        device: str = "auto",
        max_tokens: int = 512,
        temperature: float = 0.7
    ):
        """
        Initialize a local transformer model.

        Args:
            model_name: HuggingFace model identifier
            device: "cpu", "cuda", "auto"
            max_tokens: Maximum new tokens to generate
            temperature: Sampling temperature
        """
        super().__init__()

        try:
            print(f"Loading model {model_name}...")

            self.model_name = model_name
            self.device = device
            self.max_tokens = max_tokens
            self.temperature = temperature

            # Determine if we can use GPU
            if device == "auto":
                import torch
                self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # Load tokenizer and pipeline
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

            # Create text generation pipeline
            self.generator = pipeline(
                "text-generation",
                model=model_name,
                tokenizer=self.tokenizer,
                device=self.device,
                torch_dtype="auto"
            )

            print(f"Model loaded on {self.device}")

        except Exception as e:
            logger.error(f"Error loading model {model_name}: {e}")
            print(f"Error loading model: {e}")
            raise

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate text completion for the given prompt.

        Args:
            prompt: Input text

        Returns:
            Generated text completion
        """
        try:
            print(f"Generating with prompt: {prompt[:50]}...")

            # Actual generation
            response = self.generator(
                prompt,
                max_new_tokens=self.max_tokens,
                temperature=self.temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Extract generated text
            generated_text = response[0]['generated_text']

            # Remove the prompt from the beginning
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):]

            return generated_text.strip()

        except Exception as e:
            logger.error(f"Error generating text: {e}")
            print(f"Error generating text: {e}")
            return f"Error: {str(e)}"

    def generate_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate a response with tool-calling capabilities.
        This method implements the smolagents BaseModel interface for tool-calling.

        Args:
            messages: List of message objects with role and content
            tools: List of tool definitions

        Returns:
            Response with message and optional tool calls
        """
        try:
            # Format messages into a prompt
            prompt = self._format_messages_to_prompt(messages, tools)

            # Generate response
            completion = self.generate(prompt)

            # For now, just return the text without tool parsing
            # In a future enhancement, we could add tool parsing here
            return {
                "message": {
                    "role": "assistant",
                    "content": completion
                }
            }
        except Exception as e:
            logger.error(f"Error generating with tools: {e}")
            print(f"Error generating with tools: {e}")
            return {
                "message": {
                    "role": "assistant",
                    "content": f"Error: {str(e)}"
                }
            }

    def _format_messages_to_prompt(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> str:
        """Format chat messages into a text prompt for the model."""
        formatted_prompt = ""

        # Include tool descriptions if available
        if tools and len(tools) > 0:
            tool_descriptions = "\n".join([
                f"Tool {i+1}: {tool['name']} - {tool['description']}"
                for i, tool in enumerate(tools)
            ])
            formatted_prompt += f"Available tools:\n{tool_descriptions}\n\n"

        # Add conversation history
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                formatted_prompt += f"System: {content}\n\n"
            elif role == "user":
                formatted_prompt += f"User: {content}\n\n"
            elif role == "assistant":
                formatted_prompt += f"Assistant: {content}\n\n"

        # Add final prompt for assistant
        formatted_prompt += "Assistant: "

        return formatted_prompt
        # return f"Error generating response: {str(e)}"
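For reference, a standalone usage sketch of the class added above (not part of the commit; the model name is the file's own default, and the weights are downloaded on first use):

# Sketch: drive LocalTransformersModel directly, outside the agent stack.
from utils.local_model import LocalTransformersModel

model = LocalTransformersModel(
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # default declared in the file
    device="cpu",
    max_tokens=128,
)

# Plain completion
print(model.generate("User: What does the GAIA agent do?\n\nAssistant: "))

# Chat-style call; tool definitions are only folded into the prompt text, not parsed back out
result = model.generate_with_tools(
    messages=[{"role": "user", "content": "List the available tools."}],
    tools=[{"name": "web_search", "description": "Search the web"}],
)
print(result["message"]["content"])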