""" | |
Custom model implementation using Hugging Face Transformers. | |
This provides a local model implementation compatible with smolagents framework. | |
""" | |
import logging | |
from typing import Dict, List, Optional, Any | |
from smolagents.models import Model | |
from transformers import AutoTokenizer, pipeline | |
logger = logging.getLogger(__name__) | |


class LocalTransformersModel(Model):
    """Model backed by a local Hugging Face Transformers pipeline; requires no API calls."""

    def __init__(
        self,
        model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        device: str = "auto",
        max_tokens: int = 512,
        temperature: float = 0.7,
    ):
        """
        Initialize a local transformer model.

        Args:
            model_name: Hugging Face model identifier
            device: "cpu", "cuda", or "auto" (picks CUDA when available)
            max_tokens: Maximum number of new tokens to generate
            temperature: Sampling temperature
        """
        super().__init__()
        try:
            print(f"Loading model {model_name}...")
            self.model_name = model_name
            self.device = device
            self.max_tokens = max_tokens
            self.temperature = temperature

            # Resolve "auto" to a concrete device, preferring GPU when available.
            if device == "auto":
                self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # Load the tokenizer and build a text-generation pipeline around it.
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.generator = pipeline(
                "text-generation",
                model=model_name,
                tokenizer=self.tokenizer,
                device=self.device,
                torch_dtype="auto",
            )
            print(f"Model loaded on {self.device}")
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {e}")
            print(f"Error loading model: {e}")
            raise

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate a text completion for the given prompt.

        Args:
            prompt: Input text

        Returns:
            Generated text completion
        """
        try:
            print(f"Generating with prompt: {prompt[:50]}...")

            # Run the actual generation.
            response = self.generator(
                prompt,
                max_new_tokens=self.max_tokens,
                temperature=self.temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # The pipeline returns the prompt and completion together,
            # so strip the prompt from the beginning of the output.
            generated_text = response[0]["generated_text"]
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):]
            return generated_text.strip()
        except Exception as e:
            error_msg = f"Error generating text (local model): {e}"
            logger.error(error_msg)
            print(error_msg)
            return f"Error: {str(e)}"

    def generate_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Generate a response with tool-calling capabilities.

        This method follows the smolagents Model interface for tool calling.

        Args:
            messages: List of message objects with role and content
            tools: List of tool definitions

        Returns:
            Response dict with an assistant message and, eventually, tool calls
        """
        try:
            # Flatten the chat history (and any tool descriptions) into a prompt.
            prompt = self._format_messages_to_prompt(messages, tools)

            # Generate the response.
            completion = self.generate(prompt)

            # For now, return the raw text without tool parsing;
            # parsing tool calls out of the completion could be added later.
            return {
                "message": {
                    "role": "assistant",
                    "content": completion,
                }
            }
        except Exception as e:
            logger.error(f"Error generating with tools: {e}")
            print(f"Error generating with tools: {e}")
            return {
                "message": {
                    "role": "assistant",
                    "content": f"Error: {str(e)}",
                }
            }

    def _format_messages_to_prompt(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Format chat messages into a plain-text prompt for the model."""
formatted_prompt = "" | |
# Include tool descriptions if available | |
if tools and len(tools) > 0: | |
tool_descriptions = "\n".join([ | |
f"Tool {i+1}: {tool['name']} - {tool['description']}" | |
for i, tool in enumerate(tools) | |
]) | |
formatted_prompt += f"Available tools:\n{tool_descriptions}\n\n" | |
# Add conversation history | |
for msg in messages: | |
role = msg.get("role", "") | |
content = msg.get("content", "") | |
if role == "system": | |
formatted_prompt += f"System: {content}\n\n" | |
elif role == "user": | |
formatted_prompt += f"User: {content}\n\n" | |
elif role == "assistant": | |
formatted_prompt += f"Assistant: {content}\n\n" | |
# Add final prompt for assistant | |
formatted_prompt += "Assistant: " | |
return formatted_prompt | |
# return f"Error generating response: {str(e)}" |