import os
from smolagents import CodeAgent, ToolCallingAgent
from smolagents import OpenAIServerModel
from tools.fetch import fetch_webpage
from tools.yttranscript import get_youtube_transcript, get_youtube_title_description
import myprompts
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
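
# This script wires a local LFM2-1.2B model into smolagents: a reviewer
# agent first classifies each incoming question as "code" or "model", and
# BasicAgent then routes it to either the CodeAgent (gaia_agent) or the
# plain ToolCallingAgent (model_agent) defined at the bottom of the file.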

# --- Basic Agent Definition ---
class BasicAgent:
    def __init__(self):
        print("BasicAgent initialized.")
    
    def __call__(self, question: str) -> str:
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        try:
            # Use the reviewer agent to determine if the question can be answered by a model or requires code
            print("Calling reviewer agent...")
            reviewer_answer = reviewer_agent.run(myprompts.review_prompt + "\nThe question is:\n" + question)
            print(f"Reviewer agent answer: {reviewer_answer}")

            question = question + '\n' + myprompts.output_format
            fixed_answer = ""

            if reviewer_answer == "code":
                fixed_answer = gaia_agent.run(question)
                print(f"Code agent answer: {fixed_answer}")
                
            elif reviewer_answer == "model":    
                # If the reviewer agent suggests using the model, we can proceed with the model agent
                print("Using model agent to answer the question.")
                fixed_answer = model_agent.run(myprompts.model_prompt + "\nThe question is:\n" + question)
                print(f"Model agent answer: {fixed_answer}")

            return fixed_answer
        except Exception as e:
            error = f"An error occurred while processing the question: {e}"
            print(error)
            return error

# Load model and tokenizer
model_id = "LiquidAI/LFM2-1.2B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,  # load weights in bfloat16 to reduce memory use
    trust_remote_code=True,
    # attn_implementation="flash_attention_2"  # <- uncomment on compatible GPU
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Create a wrapper class that matches the expected interface
class LocalLlamaModel:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device if hasattr(model, 'device') else 'cpu'
    
    def _extract_text_from_messages(self, messages):
        """Extract text content from ChatMessage objects or handle string input"""
        if isinstance(messages, str):
            return messages
        elif isinstance(messages, list):
            # Handle list of ChatMessage objects
            text_parts = []
            for msg in messages:
                if hasattr(msg, 'content'):
                    # Handle ChatMessage with content attribute
                    if isinstance(msg.content, list):
                        # Content is a list of content items
                        for content_item in msg.content:
                            if isinstance(content_item, dict) and 'text' in content_item:
                                text_parts.append(content_item['text'])
                            elif hasattr(content_item, 'text'):
                                text_parts.append(content_item.text)
                    elif isinstance(msg.content, str):
                        text_parts.append(msg.content)
                elif isinstance(msg, dict) and 'content' in msg:
                    # Handle dictionary format
                    text_parts.append(str(msg['content']))
                else:
                    # Fallback: convert to string
                    text_parts.append(str(msg))
            return '\n'.join(text_parts)
        else:
            return str(messages)
    
    def generate(self, prompt, max_new_tokens=512*5, **kwargs):
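        """Generate a completion for the given prompt.

        The prompt may be a plain string or a list of smolagents ChatMessage
        objects; it is flattened to text before tokenization, and only the
        newly generated tokens are decoded and returned as a string.
        """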
        try:

            print("Prompt: ", prompt)
            print("Prompt type: ", type(prompt))
            # Extract text from the prompt (which might be ChatMessage objects)
            text_prompt = self._extract_text_from_messages(prompt)
            
            print("Extracted text prompt:", text_prompt[:200] + "..." if len(text_prompt) > 200 else text_prompt)
            
            # Tokenize the text prompt
            inputs = self.tokenizer(text_prompt, return_tensors="pt").to(self.model.device)
            input_ids = inputs['input_ids']
            
            # Generate output
            with torch.no_grad():
                output = self.model.generate(
                    input_ids,
                    do_sample=True,
                    temperature=0.3,
                    min_p=0.15,
                    repetition_penalty=1.05,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=self.tokenizer.eos_token_id,  # Handle padding
                )
            
            # Decode only the new tokens (exclude the input)
            new_tokens = output[0][len(input_ids[0]):]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            
            return response.strip()
            
        except Exception as e:
            print(f"Error in model generation: {e}")
            return f"Error generating response: {str(e)}"

    def __call__(self, prompt, max_new_tokens=512, **kwargs):
        """Make the model callable like a function"""
        return self.generate(prompt, max_new_tokens, **kwargs)

# Create the model instance
wrapped_model = LocalLlamaModel(model, tokenizer)

# Create the agents; each one is backed by the wrapped local model
reviewer_agent = ToolCallingAgent(model=wrapped_model, tools=[])
model_agent = ToolCallingAgent(model=wrapped_model, tools=[fetch_webpage])
gaia_agent = CodeAgent(
    tools=[fetch_webpage, get_youtube_title_description, get_youtube_transcript],
    model=wrapped_model
)

if __name__ == "__main__":
    # Example usage
    question = "What was the actual enrollment of the Malko competition in 2023?"
    agent = BasicAgent()
    answer = agent(question)
    print(f"Answer: {answer}")