# GaiaAgentEvaluator/utils/local_model.py
"""
Custom model implementation using Hugging Face Transformers.
This provides a local model implementation compatible with the smolagents framework.
"""
import logging
from typing import Dict, List, Optional, Any
from smolagents.models import Model
from transformers import AutoTokenizer, pipeline
logger = logging.getLogger(__name__)


class LocalTransformersModel(Model):
"""Model using local Hugging Face Transformers models that doesn't require API calls."""
def __init__(
self,
model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
device: str = "auto",
max_tokens: int = 512,
temperature: float = 0.7
):
"""
Initialize a local transformer model.
Args:
model_name: HuggingFace model identifier
device: "cpu", "cuda", "auto"
max_tokens: Maximum new tokens to generate
temperature: Sampling temperature
"""
super().__init__()
try:
print(f"Loading model {model_name}...")
self.model_name = model_name
self.device = device
self.max_tokens = max_tokens
self.temperature = temperature
# Determine if we can use GPU
if device == "auto":
import torch
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer and pipeline
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Create text generation pipeline
self.generator = pipeline(
"text-generation",
model=model_name,
tokenizer=self.tokenizer,
device=self.device,
torch_dtype="auto"
)
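            # Note: recent versions of transformers accept a device string
            # ("cpu"/"cuda") for the pipeline; older releases expect an
            # integer index (-1 for CPU, 0 for the first GPU).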
print(f"Model loaded on {self.device}")
except Exception as e:
logger.error(f"Error loading model {model_name}: {e}")
print(f"Error loading model: {e}")
raise

    def generate(self, prompt: str, **kwargs) -> str:
"""
Generate text completion for the given prompt.
Args:
prompt: Input text
Returns:
Generated text completion
"""
try:
print(f"Generating with prompt: {prompt[:50]}...")
# Actual generation
response = self.generator(
prompt,
max_new_tokens=self.max_tokens,
temperature=self.temperature,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id
)
# Extract generated text
generated_text = response[0]['generated_text']
# Remove the prompt from the beginning
if generated_text.startswith(prompt):
generated_text = generated_text[len(prompt):]
return generated_text.strip()
except Exception as e:
error_msg = f"Error generating text (Local model): {e}"
logger.error(error_msg)
print(error_msg)
return f"Error: {str(e)}"

    def generate_with_tools(
self,
messages: List[Dict[str, Any]],
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs
) -> Dict[str, Any]:
"""
Generate a response with tool-calling capabilities.
        This method implements the smolagents Model interface for tool calling.
Args:
messages: List of message objects with role and content
tools: List of tool definitions
Returns:
Response with message and optional tool calls
"""
try:
# Format messages into a prompt
prompt = self._format_messages_to_prompt(messages, tools)
# Generate response
completion = self.generate(prompt)
# For now, just return the text without tool parsing
# In a future enhancement, we could add tool parsing here
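            # A possible approach (sketch only): scan the completion for a JSON
            # object such as {"tool": "<name>", "arguments": {...}} and, if one
            # parses cleanly, return it under a "tool_calls" key alongside the
            # text content.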
return {
"message": {
"role": "assistant",
"content": completion
}
}
except Exception as e:
logger.error(f"Error generating with tools: {e}")
print(f"Error generating with tools: {e}")
return {
"message": {
"role": "assistant",
"content": f"Error: {str(e)}"
}
}

    def _format_messages_to_prompt(
self,
messages: List[Dict[str, Any]],
tools: Optional[List[Dict[str, Any]]] = None
) -> str:
"""Format chat messages into a text prompt for the model."""
formatted_prompt = ""
# Include tool descriptions if available
        if tools:
tool_descriptions = "\n".join([
f"Tool {i+1}: {tool['name']} - {tool['description']}"
for i, tool in enumerate(tools)
])
formatted_prompt += f"Available tools:\n{tool_descriptions}\n\n"
# Add conversation history
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if role == "system":
formatted_prompt += f"System: {content}\n\n"
elif role == "user":
formatted_prompt += f"User: {content}\n\n"
elif role == "assistant":
formatted_prompt += f"Assistant: {content}\n\n"
# Add final prompt for assistant
formatted_prompt += "Assistant: "
return formatted_prompt
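

# Minimal usage sketch: assumes the default TinyLlama checkpoint can be
# downloaded and that smolagents and transformers are installed.
if __name__ == "__main__":
    model = LocalTransformersModel(max_tokens=64)
    print(model.generate("Question: What is 2 + 2?\nAnswer:"))
    result = model.generate_with_tools(
        messages=[{"role": "user", "content": "Say hello."}]
    )
    print(result["message"]["content"])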