""" | |
Alternative model implementation using Ollama API. | |
This provides a local model implementation that doesn't require PyTorch, | |
by connecting to a locally running Ollama server. | |
""" | |
import logging | |
import requests | |
from typing import Dict, List, Optional, Any | |
from smolagents.models import Model | |
logger = logging.getLogger(__name__) | |
class OllamaModel(Model):
    """Model using the Ollama API for local inference without a PyTorch dependency."""

    def __init__(
        self,
        model_name: str = "llama2",
        api_base: str = "http://localhost:11434",
        max_tokens: int = 512,
        temperature: float = 0.7,
    ):
        """
        Initialize a connection to a local Ollama server.

        Args:
            model_name: Ollama model name (e.g., llama2, mistral, gemma)
            api_base: Base URL for the Ollama API
            max_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature
        """
        super().__init__()
        try:
            self.model_name = model_name
            self.api_base = api_base.rstrip("/")
            self.max_tokens = max_tokens
            self.temperature = temperature

            # Test the connection by listing locally available models.
            print(f"Testing connection to Ollama at {api_base}...")
            response = requests.get(f"{self.api_base}/api/tags", timeout=10)
            if response.status_code == 200:
                models = [model["name"] for model in response.json().get("models", [])]
                print(f"Available Ollama models: {models}")
                # Note: Ollama lists names with tags (e.g., "llama2:latest"),
                # so an exact-name match may miss tagged variants.
                if models and model_name not in models:
                    print(f"Warning: Model {model_name} not found. Available models: {models}")
                print("Ollama connection successful")
            else:
                print(f"Warning: Ollama server not responding correctly. Status code: {response.status_code}")
        except Exception as e:
            logger.error(f"Error connecting to Ollama: {e}")
            print(f"Error connecting to Ollama: {e}")
            print("Make sure Ollama is installed and running. Visit https://ollama.ai for installation.")
            raise
    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate a text completion using the Ollama API.

        Args:
            prompt: Input text

        Returns:
            Generated text completion
        """
        try:
            print(f"Generating with prompt: {prompt[:50]}...")

            # Prepare the request payload for Ollama's /api/generate endpoint.
            data = {
                "model": self.model_name,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": self.temperature,
                    "num_predict": self.max_tokens,
                },
            }

            # Make the API call (generous timeout: local generation can be slow).
            response = requests.post(
                f"{self.api_base}/api/generate",
                json=data,
                timeout=300,
            )
            if response.status_code != 200:
                error_msg = f"Ollama API error: {response.status_code} - {response.text}"
                print(error_msg)
                return error_msg

            # Extract the generated text from the JSON response.
            result = response.json()
            return result.get("response", "No response received")
        except Exception as e:
            logger.error(f"Error generating text with Ollama: {e}")
            print(f"Error generating text with Ollama: {e}")
            return f"Error: {str(e)}"
    def generate_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Generate a response with tool-calling capabilities using Ollama.

        Args:
            messages: List of message objects with role and content
            tools: List of tool definitions

        Returns:
            Response with a message and optional tool calls
        """
        try:
            # Flatten the chat history (and tool descriptions) into a single prompt.
            prompt = self._format_messages_to_prompt(messages, tools)

            # Generate the response.
            completion = self.generate(prompt)

            # Return the response in chat-message format.
            return {
                "message": {
                    "role": "assistant",
                    "content": completion,
                }
            }
        except Exception as e:
            logger.error(f"Error generating with tools: {e}")
            print(f"Error generating with tools: {e}")
            return {
                "message": {
                    "role": "assistant",
                    "content": f"Error: {str(e)}",
                }
            }
    def _format_messages_to_prompt(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Format chat messages into a plain-text prompt for the model."""
        formatted_prompt = ""

        # Include tool descriptions if available.
        if tools:
            tool_descriptions = "\n".join(
                f"Tool {i + 1}: {tool['name']} - {tool['description']}"
                for i, tool in enumerate(tools)
            )
            formatted_prompt += f"Available tools:\n{tool_descriptions}\n\n"

        # Add the conversation history.
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "system":
                formatted_prompt += f"System: {content}\n\n"
            elif role == "user":
                formatted_prompt += f"User: {content}\n\n"
            elif role == "assistant":
                formatted_prompt += f"Assistant: {content}\n\n"

        # Cue the model to answer as the assistant.
        formatted_prompt += "Assistant: "
        return formatted_prompt
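
# Minimal usage sketch, assuming an Ollama server is running at the default
# address and a "llama2" model has been pulled (e.g., `ollama pull llama2`).
# This is an illustrative example, not part of the class above.
if __name__ == "__main__":
    model = OllamaModel(model_name="llama2")

    # Plain text completion.
    print(model.generate("Briefly explain what Ollama is."))

    # Chat-style generation through the tool-aware entry point.
    reply = model.generate_with_tools(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ]
    )
    print(reply["message"]["content"])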