import os

from smolagents import CodeAgent, ToolCallingAgent
from smolagents import OpenAIServerModel
from tools.fetch import fetch_webpage
from tools.yttranscript import get_youtube_transcript, get_youtube_title_description
import myprompts
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

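# BasicAgent routes each incoming question: a reviewer agent first classifies it
# as "code" or "model", and the final answer is then produced either by
# gaia_agent (a CodeAgent with web and YouTube tools) or by model_agent.
# All three agents are defined at module level further below.
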
class BasicAgent:
    def __init__(self):
        print("BasicAgent initialized.")

    def __call__(self, question: str) -> str:
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        try:
            print("Calling reviewer agent...")
            reviewer_answer = reviewer_agent.run(myprompts.review_prompt + "\nThe question is:\n" + question)
            print(f"Reviewer agent answer: {reviewer_answer}")

            question = question + '\n' + myprompts.output_format
            fixed_answer = ""

            if reviewer_answer == "code":
                fixed_answer = gaia_agent.run(question)
                print(f"Code agent answer: {fixed_answer}")
            elif reviewer_answer == "model":
                print("Using model agent to answer the question.")
                fixed_answer = model_agent.run(myprompts.model_prompt + "\nThe question is:\n" + question)
                print(f"Model agent answer: {fixed_answer}")

            return fixed_answer
        except Exception as e:
            error = f"An error occurred while processing the question: {e}"
            print(error)
            return error

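# Load the local LiquidAI/LFM2-1.2B checkpoint and its tokenizer.
# device_map="auto" lets transformers place the weights on an available GPU
# (falling back to CPU), bfloat16 keeps the memory footprint low, and
# trust_remote_code=True allows any custom model code shipped with the
# checkpoint to run.
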
model_id = "LiquidAI/LFM2-1.2B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

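# Thin wrapper so the smolagents agents can drive the local Hugging Face model:
# it flattens chat-style message lists into a plain text prompt, runs generate(),
# and returns only the newly generated text as a string. For a quick standalone
# sanity check (not used by the agents), something like the following works once
# the wrapper is instantiated:
#     wrapped_model = LocalLlamaModel(model, tokenizer)
#     print(wrapped_model("Reply with the single word: ready"))
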
class LocalLlamaModel:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device if hasattr(model, 'device') else 'cpu'

    def _extract_text_from_messages(self, messages):
        """Extract text content from ChatMessage objects or handle string input."""
        if isinstance(messages, str):
            return messages
        elif isinstance(messages, list):
            text_parts = []
            for msg in messages:
                if hasattr(msg, 'content'):
                    if isinstance(msg.content, list):
                        # Content may be a list of dicts or of objects with a `text` field.
                        for content_item in msg.content:
                            if isinstance(content_item, dict) and 'text' in content_item:
                                text_parts.append(content_item['text'])
                            elif hasattr(content_item, 'text'):
                                text_parts.append(content_item.text)
                    elif isinstance(msg.content, str):
                        text_parts.append(msg.content)
                elif isinstance(msg, dict) and 'content' in msg:
                    text_parts.append(str(msg['content']))
                else:
                    text_parts.append(str(msg))
            return '\n'.join(text_parts)
        else:
            return str(messages)

    def generate(self, prompt, max_new_tokens=512 * 5, **kwargs):
        try:
            print("Prompt: ", prompt)
            print("Prompt type: ", type(prompt))

            text_prompt = self._extract_text_from_messages(prompt)
            print("Extracted text prompt:", text_prompt[:200] + "..." if len(text_prompt) > 200 else text_prompt)

            inputs = self.tokenizer(text_prompt, return_tensors="pt").to(self.model.device)
            input_ids = inputs['input_ids']

            with torch.no_grad():
                output = self.model.generate(
                    input_ids,
                    do_sample=True,
                    temperature=0.3,
                    min_p=0.15,
                    repetition_penalty=1.05,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            new_tokens = output[0][len(input_ids[0]):]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

            return response.strip()
        except Exception as e:
            print(f"Error in model generation: {e}")
            return f"Error generating response: {str(e)}"

    def __call__(self, prompt, max_new_tokens=512, **kwargs):
        """Make the model callable like a function."""
        return self.generate(prompt, max_new_tokens, **kwargs)

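# Shared wrapped model instance and the three agents built on top of it:
# - reviewer_agent: no tools, only classifies the question ("code" vs "model")
# - model_agent: ToolCallingAgent that can fetch web pages
# - gaia_agent: CodeAgent with web-fetching and YouTube title/transcript tools
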
wrapped_model = LocalLlamaModel(model, tokenizer)

reviewer_agent = ToolCallingAgent(model=wrapped_model, tools=[])
model_agent = ToolCallingAgent(model=wrapped_model, tools=[fetch_webpage])
gaia_agent = CodeAgent(
    tools=[fetch_webpage, get_youtube_title_description, get_youtube_transcript],
    model=wrapped_model
)

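# Quick manual test: run the agent on a single sample question when the module
# is executed directly.
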
if __name__ == "__main__":
    question = "What was the actual enrollment of the Malko competition in 2023?"
    agent = BasicAgent()
    answer = agent(question)
    print(f"Answer: {answer}")