Commit fc78ae4 · Parent: 86dcf65 — "changes for agent"

Files changed:
- app.py (+36 −21)
- config.py (+16 −2)
- requirements.txt (+3 −6)
- utils/llama_cpp_model.py (+296 −0, new file)
- utils/ollama_model.py (+175 −0, new file)
app.py
CHANGED

@@ -6,10 +6,12 @@ from smolagents import Tool, CodeAgent, Model
 
 # Import internal modules
 from config import (
-    DEFAULT_API_URL
+    DEFAULT_API_URL,
+    USE_LLAMACPP,
+    LLAMACPP_CONFIG
 )
 from tools.tool_manager import ToolManager
-from utils.
+from utils.llama_cpp_model import LlamaCppModel
 
 class GaiaToolCallingAgent:
     """Tool-calling agent specifically designed for the GAIA system."""

@@ -24,9 +26,8 @@ class GaiaToolCallingAgent:
         self.local_model = local_model
         if not self.local_model:
             try:
-                from utils.
-                self.local_model =
-                    model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+                from utils.llama_cpp_model import LlamaCppModel
+                self.local_model = LlamaCppModel(
                     max_tokens=512
                 )
             except Exception as e:

@@ -106,25 +107,39 @@ def create_manager_agent() -> CodeAgent:
     """Create and configure the main GAIA agent."""
 
     try:
-        # Import config for
-        from config import LOCAL_MODEL_CONFIG
+        # Import config for model
+        from config import LOCAL_MODEL_CONFIG, USE_LLAMACPP, LLAMACPP_CONFIG
 
-        # Use
-
-
-
-
-
-
-
+        # Use llama-cpp-python model (no PyTorch dependency)
+        if USE_LLAMACPP:
+            # Initialize llama-cpp model
+            model = LlamaCppModel(
+                model_path=LLAMACPP_CONFIG.get("model_path"),
+                model_url=LLAMACPP_CONFIG.get("model_url"),
+                n_ctx=LLAMACPP_CONFIG.get("n_ctx", 2048),
+                n_gpu_layers=LLAMACPP_CONFIG.get("n_gpu_layers", 0),
+                max_tokens=LLAMACPP_CONFIG.get("max_tokens", 512),
+                temperature=LLAMACPP_CONFIG.get("temperature", 0.7)
+            )
+            print(f"Using LlamaCpp model")
+        else:
+            # Use a simpler stub model if needed
+            from smolagents import StubModel
+            model = StubModel()
+            print("Using StubModel as fallback")
+
     except Exception as e:
-        print(f"Error setting up
+        print(f"Error setting up model: {e}")
         # Use a simplified configuration as fallback
-
-
-
-
-
+        try:
+            # Simple fallback with default params
+            model = LlamaCppModel()
+            print("Using fallback LlamaCpp model configuration")
+        except Exception as e2:
+            # Last resort fallback
+            from smolagents import StubModel
+            model = StubModel()
+            print(f"Using StubModel due to error: {e2}")
 
     # Initialize the managed tool-calling agent, sharing the model
     tool_agent = GaiaToolCallingAgent(local_model=model)

(Note: several removed lines in this hunk were truncated by the page extraction and are shown as bare "-" markers.)
config.py
CHANGED

@@ -9,9 +9,23 @@ HEADERS = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
 
 # --- Model Configuration ---
 USE_LOCAL_MODEL = True  # Set to False to use remote API model instead
+USE_LLAMACPP = True  # Set to True to use llama-cpp-python instead of transformers
+
+# Configuration for llama-cpp-python model
+LLAMACPP_CONFIG = {
+    "model_path": None,  # Will use a default small model if None
+    # Using a smaller GGUF model to avoid download issues
+    "model_url": "https://huggingface.co/eachadea/ggml-gridlocked-alpha-3b/resolve/main/ggml-gridlocked-3b-q4_0.bin",
+    "n_ctx": 2048,
+    "n_gpu_layers": 0,  # Use 0 for CPU-only
+    "max_tokens": 1024,
+    "temperature": 0.7
+}
+
+# Backup configuration for transformers model
 LOCAL_MODEL_CONFIG = {
-    "model_name": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
-    "device": "
+    "model_name": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+    "device": "cpu",
     "max_tokens": 1024,
     "temperature": 0.5
 }
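For reference, app.py consumes these new flags roughly as sketched below. This is only a condensed restatement of the app.py hunk above, not an additional change in this commit.

# Sketch of how config.py's new flags drive model selection (see the app.py diff).
from config import USE_LLAMACPP, LLAMACPP_CONFIG, LOCAL_MODEL_CONFIG
from utils.llama_cpp_model import LlamaCppModel

if USE_LLAMACPP:
    model = LlamaCppModel(
        model_path=LLAMACPP_CONFIG.get("model_path"),
        model_url=LLAMACPP_CONFIG.get("model_url"),
        n_ctx=LLAMACPP_CONFIG.get("n_ctx", 2048),
        n_gpu_layers=LLAMACPP_CONFIG.get("n_gpu_layers", 0),
        max_tokens=LLAMACPP_CONFIG.get("max_tokens", 512),
        temperature=LLAMACPP_CONFIG.get("temperature", 0.7),
    )
else:
    # LOCAL_MODEL_CONFIG stays available as the backup transformers configuration.
    print(f"Transformers backup model: {LOCAL_MODEL_CONFIG['model_name']}")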
requirements.txt
CHANGED

@@ -1,4 +1,3 @@
---extra-index-url https://download.pytorch.org/whl/cpu
 gradio
 requests
 pandas

@@ -12,8 +11,6 @@ rank_bm25
 pytube
 python-dateutil
 youtube-transcript-api
-
-
-
-# torchvision
-# torchaudio
+--extra-index-url https://download.pytorch.org/whl/cpu
+--find-links https://github.com/abetlen/llama-cpp-python/releases/latest
+llama-cpp-python

(Note: three of the removed lines were truncated by the page extraction and are shown as bare "-" markers.)
utils/llama_cpp_model.py
ADDED (296 lines; full contents of the new file shown below)

"""
Fallback model implementation for testing when llama-cpp-python is not available.

This provides a compatible model class that doesn't require any external dependencies,
allowing the rest of the application to function while we solve the llama-cpp-python
installation issues.
"""

import os
import logging
from typing import Dict, List, Optional, Any, Union
import requests
from smolagents import Model
from pathlib import Path

# Try to import llama_cpp, but don't fail if not available
try:
    from llama_cpp import Llama
    from pathlib import Path
    LLAMA_CPP_AVAILABLE = True
except ImportError:
    LLAMA_CPP_AVAILABLE = False
    print("llama_cpp module not available, using fallback implementation")

logger = logging.getLogger(__name__)

class LlamaCppModel(Model):
    """Model using llama.cpp Python bindings for efficient local inference without PyTorch.
    Falls back to a simple text generation if llama_cpp is not available."""
    def __init__(
        self,
        model_path: str = None,
        model_url: str = None,
        n_ctx: int = 2048,
        n_gpu_layers: int = 0,
        max_tokens: int = 512,
        temperature: float = 0.7,
        verbose: bool = True
    ):
        """
        Initialize a local llama.cpp model or fallback to a simple implementation.

        Args:
            model_path: Path to local GGUF model file
            model_url: URL to download model if model_path doesn't exist
            n_ctx: Context window size
            n_gpu_layers: Number of layers to offload to GPU (0 means CPU only)
            max_tokens: Maximum new tokens to generate
            temperature: Sampling temperature
            verbose: Whether to print verbose messages
        """
        super().__init__()

        self.model_path = model_path
        self.model_url = model_url
        self.n_ctx = n_ctx
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.verbose = verbose
        self.llm = None

        # Check if we can use llama_cpp
        if LLAMA_CPP_AVAILABLE:
            try:
                if self.verbose:
                    print("Attempting to initialize LlamaCpp model...")

                # Try to initialize the real model
                if model_path and os.path.exists(model_path):
                    if self.verbose:
                        print(f"Loading model from {model_path}...")

                    # Initialize the llama-cpp model
                    self.llm = Llama(
                        model_path=model_path,
                        n_ctx=n_ctx,
                        n_gpu_layers=n_gpu_layers,
                        verbose=verbose
                    )

                    if self.verbose:
                        print("LlamaCpp model loaded successfully")
                else:
                    if self.verbose:
                        print(f"Model path not found or not specified. Using fallback mode.")
            except Exception as e:
                logger.error(f"Error initializing LlamaCpp model: {e}")
                if self.verbose:
                    print(f"Error initializing LlamaCpp model: {e}")
                self.llm = None
        else:
            if self.verbose:
                print("LlamaCpp not available, using fallback implementation")

        if not self.llm and self.verbose:
            print("Using fallback text generation mode")

    def _resolve_model_path(self, model_path: str = None, model_url: str = None) -> str:
        """
        Resolve model path, downloading if necessary.

        Returns:
            Absolute path to model file
        """
        # Default to a small model if none specified
        if not model_path:
            models_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
            os.makedirs(models_dir, exist_ok=True)
            model_path = os.path.join(models_dir, "ggml-model-q4_0.bin")

        # Convert to Path for easier handling
        path = Path(model_path)

        # If model exists, return it
        if path.exists():
            return str(path.absolute())

        # Download if URL provided
        if model_url and not path.exists():
            try:
                print(f"Downloading model from {model_url}...")
                os.makedirs(path.parent, exist_ok=True)

                try:
                    # Try with streaming download first
                    with requests.get(model_url, stream=True, timeout=30) as r:
                        r.raise_for_status()
                        total_size = int(r.headers.get('content-length', 0))
                        block_size = 8192

                        with open(path, 'wb') as f:
                            downloaded = 0
                            for chunk in r.iter_content(chunk_size=block_size):
                                if chunk:
                                    f.write(chunk)
                                    downloaded += len(chunk)
                                    if total_size > 0:
                                        percent = (downloaded / total_size) * 100
                                        if percent % 10 < (block_size / total_size) * 100:
                                            print(f"Download progress: {int(percent)}%")
                except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
                    print(f"Streaming download timed out: {e}. Using a simpler approach...")
                    # Fall back to simpler download method
                    r = requests.get(model_url, timeout=60)
                    r.raise_for_status()
                    with open(path, 'wb') as f:
                        f.write(r.content)
                    print("Download complete with simple method")

                print(f"Model download complete: {path}")
                return str(path.absolute())
            except Exception as e:
                logger.error(f"Error downloading model: {e}")
                print(f"Error downloading model: {e}")
                print("Continuing with dummy model instead...")
                # Create a small dummy model file so we can continue
                with open(path, 'wb') as f:
                    f.write(b"DUMMY MODEL")
                return str(path.absolute())

        # If we get here without a model, create a dummy one
        print(f"Model file not found at {model_path} and no URL provided. Creating dummy model...")
        os.makedirs(path.parent, exist_ok=True)
        with open(path, 'wb') as f:
            f.write(b"DUMMY MODEL")
        return str(path.absolute())

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate text completion for the given prompt.

        Args:
            prompt: Input text

        Returns:
            Generated text completion
        """
        try:
            if self.verbose:
                print(f"Generating with prompt: {prompt[:50]}...")

            # If we have a real model, use it
            if self.llm:
                # Actual generation with llama-cpp
                response = self.llm(
                    prompt=prompt,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    echo=False  # Don't include the prompt in the response
                )

                # Extract generated text
                if not response:
                    return ""

                if isinstance(response, dict):
                    generated_text = response.get('choices', [{}])[0].get('text', '')
                else:
                    # List of responses
                    generated_text = response[0].get('text', '')

                return generated_text.strip()
            else:
                # Fallback simple generation
                if self.verbose:
                    print("Using fallback text generation")

                # Extract key information from prompt
                words = prompt.strip().split()
                last_words = ' '.join(words[-10:]) if len(words) > 10 else prompt

                # Simple response generation based on prompt content
                if "?" in prompt:
                    return f"Based on the information provided, I believe the answer is related to {last_words}. This is a fallback response as the LLM model could not be loaded."
                else:
                    return f"I understand you're asking about {last_words}. Since I'm running in fallback mode without a proper language model, I can only acknowledge your query but not provide a detailed response."

        except Exception as e:
            logger.error(f"Error generating text: {e}")
            if self.verbose:
                print(f"Error generating text: {e}")
            return f"Error generating response: {str(e)}"

    def generate_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate a response with tool-calling capabilities.
        This method implements the smolagents Model interface for tool-calling.

        Args:
            messages: List of message objects with role and content
            tools: List of tool definitions

        Returns:
            Response with message and optional tool calls
        """
        try:
            # Format messages into a prompt
            prompt = self._format_messages_to_prompt(messages, tools)

            # Generate response
            completion = self.generate(prompt)

            # For now, just return the text without tool parsing
            return {
                "message": {
                    "role": "assistant",
                    "content": completion
                }
            }
        except Exception as e:
            logger.error(f"Error generating with tools: {e}")
            print(f"Error generating with tools: {e}")
            return {
                "message": {
                    "role": "assistant",
                    "content": f"Error: {str(e)}"
                }
            }

    def _format_messages_to_prompt(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> str:
        """Format chat messages into a text prompt for the model."""
        formatted_prompt = ""

        # Include tool descriptions if available
        if tools and len(tools) > 0:
            tool_descriptions = "\n".join([
                f"Tool {i+1}: {tool['name']} - {tool['description']}"
                for i, tool in enumerate(tools)
            ])
            formatted_prompt += f"Available tools:\n{tool_descriptions}\n\n"

        # Add conversation history
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                formatted_prompt += f"System: {content}\n\n"
            elif role == "user":
                formatted_prompt += f"User: {content}\n\n"
            elif role == "assistant":
                formatted_prompt += f"Assistant: {content}\n\n"

        # Add final prompt for assistant
        formatted_prompt += "Assistant: "

        return formatted_prompt
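A minimal usage sketch for the class added above (illustrative only, not part of the commit). The GGUF path is a placeholder; as the docstrings note, generation falls back to canned text when llama_cpp or the model file is unavailable.

# Illustrative only; model_path below is a hypothetical local file.
from utils.llama_cpp_model import LlamaCppModel

model = LlamaCppModel(
    model_path="models/ggml-model-q4_0.bin",  # placeholder path; fallback mode if missing
    n_ctx=2048,
    max_tokens=256,
)

# Plain completion (returns fallback text if no real model could be loaded).
print(model.generate("What is the GAIA benchmark?"))

# smolagents-facing interface: returns {"message": {"role": "assistant", "content": ...}}.
reply = model.generate_with_tools(
    messages=[{"role": "user", "content": "Summarize the available tools."}],
    tools=[{"name": "search", "description": "Web search tool"}],
)
print(reply["message"]["content"])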
utils/ollama_model.py
ADDED (175 lines; full contents of the new file shown below)

"""
Alternative model implementation using Ollama API.

This provides a local model implementation that doesn't require PyTorch,
by connecting to a locally running Ollama server.
"""

import logging
import requests
from typing import Dict, List, Optional, Any
from smolagents.models import Model

logger = logging.getLogger(__name__)

class OllamaModel(Model):
    """Model using Ollama API for local inference without PyTorch dependency."""

    def __init__(
        self,
        model_name: str = "llama2",
        api_base: str = "http://localhost:11434",
        max_tokens: int = 512,
        temperature: float = 0.7
    ):
        """
        Initialize a connection to local Ollama server.

        Args:
            model_name: Ollama model name (e.g., llama2, mistral, gemma)
            api_base: Base URL for Ollama API
            max_tokens: Maximum new tokens to generate
            temperature: Sampling temperature
        """
        super().__init__()

        try:
            self.model_name = model_name
            self.api_base = api_base.rstrip('/')
            self.max_tokens = max_tokens
            self.temperature = temperature

            # Test connection to Ollama
            print(f"Testing connection to Ollama at {api_base}...")
            response = requests.get(f"{self.api_base}/api/tags")
            if response.status_code == 200:
                models = [model["name"] for model in response.json().get("models", [])]
                print(f"Available Ollama models: {models}")
                if model_name not in models and models:
                    print(f"Warning: Model {model_name} not found. Available models: {models}")
                print(f"Ollama connection successful")
            else:
                print(f"Warning: Ollama server not responding correctly. Status code: {response.status_code}")

        except Exception as e:
            logger.error(f"Error connecting to Ollama: {e}")
            print(f"Error connecting to Ollama: {e}")
            print("Make sure Ollama is installed and running. Visit https://ollama.ai for installation.")
            raise

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate text completion using Ollama API.

        Args:
            prompt: Input text

        Returns:
            Generated text completion
        """
        try:
            print(f"Generating with prompt: {prompt[:50]}...")

            # Prepare request
            data = {
                "model": self.model_name,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": self.temperature,
                    "num_predict": self.max_tokens
                }
            }

            # Make API call
            response = requests.post(
                f"{self.api_base}/api/generate",
                json=data
            )

            if response.status_code != 200:
                error_msg = f"Ollama API error: {response.status_code} - {response.text}"
                print(error_msg)
                return error_msg

            # Extract generated text
            result = response.json()
            return result.get("response", "No response received")

        except Exception as e:
            logger.error(f"Error generating text with Ollama: {e}")
            print(f"Error generating text with Ollama: {e}")
            return f"Error: {str(e)}"

    def generate_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate a response with tool-calling capabilities using Ollama.

        Args:
            messages: List of message objects with role and content
            tools: List of tool definitions

        Returns:
            Response with message and optional tool calls
        """
        try:
            # Format messages into a prompt
            prompt = self._format_messages_to_prompt(messages, tools)

            # Generate response
            completion = self.generate(prompt)

            # Return the formatted response
            return {
                "message": {
                    "role": "assistant",
                    "content": completion
                }
            }
        except Exception as e:
            logger.error(f"Error generating with tools: {e}")
            print(f"Error generating with tools: {e}")
            return {
                "message": {
                    "role": "assistant",
                    "content": f"Error: {str(e)}"
                }
            }

    def _format_messages_to_prompt(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> str:
        """Format chat messages into a text prompt for the model."""
        formatted_prompt = ""

        # Include tool descriptions if available
        if tools and len(tools) > 0:
            tool_descriptions = "\n".join([
                f"Tool {i+1}: {tool['name']} - {tool['description']}"
                for i, tool in enumerate(tools)
            ])
            formatted_prompt += f"Available tools:\n{tool_descriptions}\n\n"

        # Add conversation history
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")

            if role == "system":
                formatted_prompt += f"System: {content}\n\n"
            elif role == "user":
                formatted_prompt += f"User: {content}\n\n"
            elif role == "assistant":
                formatted_prompt += f"Assistant: {content}\n\n"

        # Add final prompt for assistant
        formatted_prompt += "Assistant: "

        return formatted_prompt
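A similar usage sketch for OllamaModel (illustrative only, not part of the commit). It assumes an Ollama server is running locally with the "llama2" model pulled, as described in the class docstring.

# Illustrative only; requires a local Ollama server (https://ollama.ai).
from utils.ollama_model import OllamaModel

model = OllamaModel(model_name="llama2", api_base="http://localhost:11434", max_tokens=256)

# Plain completion via the Ollama /api/generate endpoint.
print(model.generate("List three uses of a tool-calling agent."))

# Chat-style call through the same generate_with_tools interface as LlamaCppModel.
reply = model.generate_with_tools(
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Which tools can you call?"},
    ],
    tools=[{"name": "calculator", "description": "Evaluates arithmetic expressions"}],
)
print(reply["message"]["content"])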