import json
import time
from typing import Dict, List, Optional

from huggingface_hub import hf_hub_download
from smolagents.models import Model, ChatMessage, Tool, MessageRole


class FakeModelClass(Model):
    """A model class that returns pre-recorded responses from a log file.

    This class is useful for testing and debugging, as it does not make
    actual API calls but instead replays responses recorded in a log file
    hosted on the Hugging Face Hub.

    Parameters:
        log_folder (str):
            Name of the log folder inside the dataset repository; it must
            contain a metadata.json file with the recorded model outputs.
        **kwargs: Additional keyword arguments passed to the Model base class.
    """

    def __init__(
        self,
        log_folder: str,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dataset_name = "smolagents/computer-agent-logs"
        self.log_folder = log_folder
        self.call_counter = 0
        self.model_outputs = self._load_model_outputs()

    def _load_model_outputs(self) -> List[str]:
        """Load model outputs from the log's metadata.json on the Hugging Face Hub."""
        # Download the metadata file for the requested log folder from the dataset repo.
        file_path = hf_hub_download(
            repo_id=self.dataset_name,
            filename=self.log_folder + "/metadata.json",
            repo_type="dataset"
        )

        with open(file_path, "r") as f:
            log_data = json.load(f)

        # Collect the recorded assistant outputs, one per tool-call step.
        model_outputs = []
        for step in log_data.get("tool_calls", []):
            if "model_output_message" in step:
                model_outputs.append(step["model_output_message"])

        print(f"Loaded {len(model_outputs)} model outputs from log file")
        return model_outputs
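
    # Sketch of the metadata.json layout assumed by the parsing above. Only the
    # "tool_calls" list and its "model_output_message" fields are relied on here;
    # any other fields present in the real logs are simply ignored:
    #
    #   {
    #     "tool_calls": [
    #       {"model_output_message": "<assistant output for step 1>"},
    #       {"model_output_message": "<assistant output for step 2>"}
    #     ]
    #   }
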
    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        tools_to_call_from: Optional[List[Tool]] = None,
        **kwargs
    ) -> ChatMessage:
        """Return the next pre-recorded response from the log file.

        Parameters:
            messages: List of input messages (ignored).
            stop_sequences: Optional list of stop sequences (ignored).
            grammar: Optional grammar specification (ignored).
            tools_to_call_from: Optional list of tools (ignored).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            ChatMessage: The next pre-recorded response.
        """
        # Small pause to roughly mimic the latency of a real API call.
        time.sleep(1.0)

        # Return the next recorded output, or a fallback message once exhausted.
        if self.call_counter < len(self.model_outputs):
            content = self.model_outputs[self.call_counter]
            self.call_counter += 1
        else:
            content = "No more pre-recorded responses available."

        # Rough token-count estimates (about 4 characters per token).
        self.last_input_token_count = len(str(messages)) // 4
        self.last_output_token_count = len(content) // 4

        return ChatMessage(
            role=MessageRole.ASSISTANT,
            content=content,
            tool_calls=None,
            raw={"source": "pre-recorded log", "call_number": self.call_counter}
        )
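

# Minimal usage sketch. The folder name below is a placeholder and must be
# replaced with a log folder that actually exists in the
# smolagents/computer-agent-logs dataset; the model then replays that log's
# recorded outputs one call at a time. It could also be passed as the `model`
# of a smolagents agent (e.g. CodeAgent) to step through a recorded run.
if __name__ == "__main__":
    model = FakeModelClass(log_folder="example_log_folder")  # placeholder folder name
    reply = model(messages=[{"role": "user", "content": "ignored"}])  # input is ignored
    print(reply.content)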