Kunal Pai
Refactor LambdaAgent to use OpenAI client and update cost manager with new Lambda model expenses
fcdfb63
from abc import ABC, abstractmethod | |
from typing import Dict, Type, Any, Optional, Tuple | |
import os | |
import json | |
import ollama | |
from openai import OpenAI | |
from src.manager.utils.singleton import singleton | |
from src.manager.utils.streamlit_interface import output_assistant_response | |
from google import genai | |
from google.genai import types | |
from google.genai.types import * | |
from groq import Groq | |
import os | |
from dotenv import load_dotenv | |
from src.manager.budget_manager import BudgetManager | |
MODEL_PATH = "./src/models/" | |
MODEL_FILE_PATH = "./src/models/models.json" | |
class Agent(ABC): | |
def __init__(self, agent_name: str, | |
base_model: str, | |
system_prompt: str, | |
create_resource_cost: int, | |
invoke_resource_cost: int, | |
create_expense_cost: int = 0, | |
invoke_expense_cost: int = 0, | |
output_expense_cost: int = 0): | |
self.agent_name = agent_name | |
self.base_model = base_model | |
self.system_prompt = system_prompt | |
self.create_resource_cost = create_resource_cost | |
self.invoke_resource_cost = invoke_resource_cost | |
self.create_expense_cost = create_expense_cost | |
self.invoke_expense_cost = invoke_expense_cost | |
self.output_expense_cost = output_expense_cost | |
self.create_model() | |
def create_model(self) -> None: | |
"""Create and Initialize agent""" | |
pass | |
def ask_agent(self, prompt: str) -> str: | |
"""ask agent a question""" | |
pass | |
def delete_agent(self) -> None: | |
"""delete agent""" | |
pass | |
def get_type(self) -> None: | |
"""get agent type""" | |
pass | |
def get_costs(self): | |
return { | |
"create_resource_cost": self.create_resource_cost, | |
"invoke_resource_cost": self.invoke_resource_cost, | |
"create_expense_cost": self.create_expense_cost, | |
"invoke_expense_cost": self.invoke_expense_cost, | |
"output_expense_cost": self.output_expense_cost, | |
} | |
class OllamaAgent(Agent): | |
type = "local" | |
def create_model(self): | |
ollama_response = ollama.create( | |
model=self.agent_name, | |
from_=self.base_model, | |
system=self.system_prompt, | |
stream=False | |
) | |
def ask_agent(self, prompt): | |
output_assistant_response(f"Asked Agent {self.agent_name} a question") | |
agent_response = ollama.chat( | |
model=self.agent_name, | |
messages=[{"role": "user", "content": prompt}], | |
) | |
output_assistant_response( | |
f"Agent {self.agent_name} answered with {agent_response.message.content}") | |
return agent_response.message.content | |
def delete_agent(self): | |
ollama.delete(self.agent_name) | |
def get_type(self): | |
return self.type | |
class GeminiAgent(Agent): | |
type = "cloud" | |
def __init__(self, | |
agent_name: str, | |
base_model: str, | |
system_prompt: str, | |
create_resource_cost: int, | |
invoke_resource_cost: int, | |
create_expense_cost: int = 0, | |
invoke_expense_cost: int = 0, | |
output_expense_cost: int = 0): | |
load_dotenv() | |
self.api_key = os.getenv("GEMINI_KEY") | |
if not self.api_key: | |
raise ValueError( | |
"Google API key is required for Gemini models. Set GOOGLE_API_KEY environment variable or pass api_key parameter.") | |
# Initialize the Gemini API | |
self.client = genai.Client(api_key=self.api_key) | |
self.chat = self.client.chats.create(model=base_model) | |
# Call parent constructor after API setup | |
super().__init__(agent_name, | |
base_model, | |
system_prompt, | |
create_resource_cost, | |
invoke_resource_cost, | |
create_expense_cost, | |
invoke_expense_cost, | |
output_expense_cost) | |
def create_model(self): | |
self.messages = [] | |
def ask_agent(self, prompt): | |
response = self.chat.send_message( | |
message=prompt, | |
config=types.GenerateContentConfig( | |
system_instruction=self.system_prompt, | |
) | |
) | |
return response.text | |
def delete_agent(self): | |
self.messages = [] | |
def get_type(self): | |
return self.type | |
class GroqAgent(Agent): | |
type = "cloud" | |
def __init__( | |
self, | |
agent_name: str, | |
base_model: str, | |
system_prompt: str, | |
create_resource_cost: int, | |
invoke_resource_cost: int, | |
create_expense_cost: int = 0, | |
invoke_expense_cost: int = 0, | |
output_expense_cost: int = 0 | |
): | |
# Call the parent class constructor first | |
super().__init__(agent_name, base_model, system_prompt, | |
create_resource_cost, invoke_resource_cost, | |
create_expense_cost, invoke_expense_cost, | |
output_expense_cost) | |
# Groq-specific API client setup | |
api_key = os.getenv("GROQ_API_KEY") | |
if not api_key: | |
raise ValueError("GROQ_API_KEY environment variable not set. Please set it in your .env file or environment.") | |
self.client = Groq(api_key=api_key) | |
if self.base_model and "groq-" in self.base_model: | |
self.groq_api_model_name = self.base_model.split("groq-", 1)[1] | |
else: | |
# Fallback or error if the naming convention isn't followed. | |
# This ensures that if a non-prefixed model name is somehow passed, | |
# it might still work, or you can raise an error. | |
self.groq_api_model_name = self.base_model | |
print(f"Warning: GroqAgent base_model '{self.base_model}' does not follow 'groq-' prefix convention.") | |
def create_model(self) -> None: | |
""" | |
Create and Initialize agent. | |
For Groq, models are pre-existing on their cloud. | |
This method is called by Agent's __init__. | |
""" | |
pass | |
def ask_agent(self, prompt: str) -> str: | |
"""Ask agent a question""" | |
if not self.client: | |
raise ConnectionError("Groq client not initialized. Check API key and constructor.") | |
if not self.groq_api_model_name: | |
raise ValueError("Groq API model name not set. Check base_model configuration.") | |
messages = [ | |
{"role": "system", "content": self.system_prompt}, | |
{"role": "user", "content": prompt}, | |
] | |
try: | |
response = self.client.chat.completions.create( | |
messages=messages, | |
model=self.groq_api_model_name, # Use the derived model name for Groq API | |
) | |
result = response.choices[0].message.content | |
return result | |
except Exception as e: | |
# Handle API errors or other exceptions during the call | |
print(f"Error calling Groq API: {e}") | |
raise # Re-raise the exception or handle it as appropriate | |
def delete_agent(self) -> None: | |
"""Delete agent""" | |
pass | |
def get_type(self) -> str: # Ensure return type hint matches Agent ABC | |
"""Get agent type""" | |
return self.type | |
class LambdaAgent(Agent): | |
type = "cloud" | |
def __init__(self, | |
agent_name: str, | |
base_model: str, | |
system_prompt: str, | |
create_resource_cost: int, | |
invoke_resource_cost: int, | |
create_expense_cost: int = 0, | |
invoke_expense_cost: int = 0, | |
output_expense_cost: int = 0, | |
api_key: str = ""): | |
self.lambda_url = "https://api.lambda.ai/v1" | |
self.api_key = api_key or os.getenv("LAMBDA_API_KEY") | |
self.lambda_model = base_model.split("lambda-")[1] if base_model.startswith("lambda-") else base_model | |
if not self.api_key: | |
raise ValueError("Lambda API key must be provided or set in LAMBDA_API_KEY environment variable.") | |
self.client = client = OpenAI( | |
api_key=self.api_key, | |
base_url=self.lambda_url, | |
) | |
super().__init__(agent_name, | |
base_model, | |
system_prompt, | |
create_resource_cost, | |
invoke_resource_cost, | |
create_expense_cost, | |
invoke_expense_cost, | |
output_expense_cost) | |
def create_model(self) -> None: | |
pass # Lambda already deployed | |
def ask_agent(self, prompt: str) -> str: | |
"""Ask agent a question""" | |
try: | |
response = self.client.chat.completions.create( | |
model=self.lambda_model, | |
messages=[ | |
{"role": "system", "content": self.system_prompt}, | |
{"role": "user", "content": prompt} | |
], | |
) | |
return response.choices[0].message.content | |
except Exception as e: | |
output_assistant_response(f"Error asking agent: {e}") | |
raise | |
def delete_agent(self) -> None: | |
pass | |
def get_type(self) -> str: | |
return self.type | |
class AgentManager(): | |
budget_manager: BudgetManager = BudgetManager() | |
is_creation_enabled: bool = True | |
is_cloud_invocation_enabled: bool = True | |
is_local_invocation_enabled: bool = True | |
def __init__(self): | |
self._agents: Dict[str, Agent] = {} | |
self._agent_types = { | |
"ollama": OllamaAgent, | |
"gemini": GeminiAgent, | |
"groq": GroqAgent, | |
"lambda": LambdaAgent, | |
} | |
self._load_agents() | |
def set_creation_mode(self, status: bool): | |
self.is_creation_enabled = status | |
if status: | |
output_assistant_response("Agent creation mode is enabled.") | |
else: | |
output_assistant_response("Agent creation mode is disabled.") | |
def set_cloud_invocation_mode(self, status: bool): | |
self.is_cloud_invocation_enabled = status | |
if status: | |
output_assistant_response("Cloud invocation mode is enabled.") | |
else: | |
output_assistant_response("Cloud invocation mode is disabled.") | |
def set_local_invocation_mode(self, status: bool): | |
self.is_local_invocation_enabled = status | |
if status: | |
output_assistant_response("Local invocation mode is enabled.") | |
else: | |
output_assistant_response("Local invocation mode is disabled.") | |
def create_agent(self, agent_name: str, | |
base_model: str, system_prompt: str, | |
description: str = "", create_resource_cost: float = 0, | |
invoke_resource_cost: float = 0, | |
create_expense_cost: float = 0, | |
invoke_expense_cost: float = 0, | |
output_expense_cost: float = 0, | |
**additional_params) -> Tuple[Agent, int]: | |
if not self.is_creation_enabled: | |
raise ValueError("Agent creation mode is disabled.") | |
if agent_name in self._agents: | |
raise ValueError(f"Agent {agent_name} already exists") | |
self._agents[agent_name] = self.create_agent_class( | |
agent_name, | |
base_model, | |
system_prompt, | |
description=description, | |
create_resource_cost=create_resource_cost, | |
invoke_resource_cost=invoke_resource_cost, | |
create_expense_cost=create_expense_cost, | |
invoke_expense_cost=invoke_expense_cost, | |
output_expense_cost=output_expense_cost, | |
**additional_params # For any future parameters we might want to add | |
) | |
# save agent to file | |
self._save_agent( | |
agent_name, | |
base_model, | |
system_prompt, | |
description=description, | |
create_resource_cost=create_resource_cost, | |
invoke_resource_cost=invoke_resource_cost, | |
create_expense_cost=create_expense_cost, | |
invoke_expense_cost=invoke_expense_cost, | |
output_expense_cost=output_expense_cost, | |
**additional_params # For any future parameters we might want to add | |
) | |
return (self._agents[agent_name], | |
self.budget_manager.get_current_remaining_resource_budget(), | |
self.budget_manager.get_current_remaining_expense_budget()) | |
def validate_budget(self, | |
resource_cost: float = 0, | |
expense_cost: float = 0) -> None: | |
if not self.budget_manager.can_spend_resource(resource_cost): | |
raise ValueError(f"Do not have enough resource budget to create/use the agent. " | |
+ f"Creating/Using the agent costs {resource_cost} but only {self.budget_manager.get_current_remaining_resource_budget()} is remaining") | |
if not self.budget_manager.can_spend_expense(expense_cost): | |
raise ValueError(f"Do not have enough expense budget to create/use the agent. " | |
+ f"Creating/Using the agent costs {expense_cost} but only {self.budget_manager.get_current_remaining_expense_budget()} is remaining") | |
def create_agent_class(self, | |
agent_name: str, | |
base_model: str, | |
system_prompt: str, | |
description: str = "", | |
create_resource_cost: float = 0, | |
invoke_resource_cost: float = 0, | |
create_expense_cost: float = 0, | |
invoke_expense_cost: float = 0, | |
output_expense_cost: float = 0, | |
**additional_params) -> Agent: | |
agent_type = self._get_agent_type(base_model) | |
agent_class = self._agent_types.get(agent_type) | |
if not agent_class: | |
raise ValueError(f"Unsupported base model {base_model}") | |
created_agent = agent_class(agent_name, | |
base_model, | |
system_prompt, | |
create_resource_cost, | |
invoke_resource_cost, | |
create_expense_cost, | |
invoke_expense_cost, | |
output_expense_cost, | |
**additional_params) | |
self.validate_budget(create_resource_cost, | |
create_expense_cost) | |
self.budget_manager.add_to_resource_budget(create_resource_cost) | |
self.budget_manager.add_to_expense_budget(create_expense_cost) | |
# create agent | |
return created_agent | |
def get_agent(self, agent_name: str) -> Agent: | |
"""Get existing agent by name""" | |
if agent_name not in self._agents: | |
raise ValueError(f"Agent {agent_name} does not exists") | |
return self._agents[agent_name] | |
def list_agents(self) -> dict: | |
"""Return agent information (name, description, costs)""" | |
try: | |
if os.path.exists(MODEL_FILE_PATH): | |
with open(MODEL_FILE_PATH, "r", encoding="utf8") as f: | |
full_models = json.loads(f.read()) | |
# Create a simplified version with only the description and costs | |
simplified_agents = {} | |
for name, data in full_models.items(): | |
simplified_agents[name] = { | |
"description": data.get("description", ""), | |
"create_resource_cost": data.get("create_resource_cost", 0), | |
"invoke_resource_cost": data.get("invoke_resource_cost", 0), | |
"create_expense_cost": data.get("create_expense_cost", 0), | |
"invoke_expense_cost": data.get("invoke_expense_cost", 0), | |
"base_model": data.get("base_model", ""), | |
} | |
return simplified_agents | |
else: | |
return {} | |
except Exception as e: | |
output_assistant_response(f"Error listing agents: {e}") | |
return {} | |
def delete_agent(self, agent_name: str) -> int: | |
agent: Agent = self.get_agent(agent_name) | |
self.budget_manager.remove_from_resource_expense( | |
agent.create_resource_cost) | |
agent.delete_agent() | |
del self._agents[agent_name] | |
try: | |
if os.path.exists(MODEL_FILE_PATH): | |
with open(MODEL_FILE_PATH, "r", encoding="utf8") as f: | |
models = json.loads(f.read()) | |
del models[agent_name] | |
with open(MODEL_FILE_PATH, "w", encoding="utf8") as f: | |
f.write(json.dumps(models, indent=4)) | |
except Exception as e: | |
output_assistant_response(f"Error deleting agent: {e}") | |
return (self.budget_manager.get_current_remaining_resource_budget(), | |
self.budget_manager.get_current_remaining_expense_budget()) | |
def ask_agent(self, agent_name: str, prompt: str) -> Tuple[str, int]: | |
agent: Agent = self.get_agent(agent_name) | |
print(agent.get_type()) | |
print(agent_name) | |
print(self.is_local_invocation_enabled, | |
self.is_cloud_invocation_enabled) | |
if not self.is_local_invocation_enabled and agent.get_type() == "local": | |
raise ValueError("Local invocation mode is disabled.") | |
if not self.is_cloud_invocation_enabled and agent.get_type() == "cloud": | |
raise ValueError("Cloud invocation mode is disabled.") | |
n_tokens = len(prompt.split())/1000000 | |
self.validate_budget(agent.invoke_resource_cost, | |
agent.invoke_expense_cost*n_tokens) | |
self.budget_manager.add_to_expense_budget( | |
agent.invoke_expense_cost*n_tokens) | |
response = agent.ask_agent(prompt) | |
n_tokens = len(response.split())/1000000 | |
self.budget_manager.add_to_expense_budget( | |
agent.output_expense_cost*n_tokens) | |
return (response, | |
self.budget_manager.get_current_remaining_resource_budget(), | |
self.budget_manager.get_current_remaining_expense_budget()) | |
def _save_agent(self, | |
agent_name: str, | |
base_model: str, | |
system_prompt: str, | |
description: str = "", | |
create_resource_cost: float = 0, | |
invoke_resource_cost: float = 0, | |
create_expense_cost: float = 0, | |
invoke_expense_cost: float = 0, | |
output_expense_cost: float = 0, | |
**additional_params) -> None: | |
"""Save a single agent to the models.json file""" | |
try: | |
# Ensure the directory exists | |
os.makedirs(MODEL_PATH, exist_ok=True) | |
# Read existing models file or create empty dict if it doesn't exist | |
try: | |
with open(MODEL_FILE_PATH, "r", encoding="utf8") as f: | |
models = json.loads(f.read()) | |
except (FileNotFoundError, json.JSONDecodeError): | |
models = {} | |
# Update the models dict with the new agent | |
models[agent_name] = { | |
"base_model": base_model, | |
"description": description, | |
"system_prompt": system_prompt, | |
"create_resource_cost": create_resource_cost, | |
"invoke_resource_cost": invoke_resource_cost, | |
"create_expense_cost": create_expense_cost, | |
"invoke_expense_cost": invoke_expense_cost, | |
"output_expense_cost": output_expense_cost, | |
} | |
# Add any additional parameters that were passed | |
for key, value in additional_params.items(): | |
models[agent_name][key] = value | |
# Write the updated models back to the file | |
with open(MODEL_FILE_PATH, "w", encoding="utf8") as f: | |
f.write(json.dumps(models, indent=4)) | |
except Exception as e: | |
output_assistant_response(f"Error saving agent {agent_name}: {e}") | |
def _get_agent_type(self, base_model) -> str: | |
if base_model == "llama3.2": | |
return "ollama" | |
elif base_model == "mistral": | |
return "ollama" | |
elif base_model == "deepseek-r1": | |
return "ollama" | |
elif "gemini" in base_model: | |
return "gemini" | |
elif "groq" in base_model: | |
return "groq" | |
elif base_model.startswith("lambda-"): | |
return "lambda" | |
else: | |
return "unknown" | |
def _load_agents(self) -> None: | |
"""Load agent configurations from disk""" | |
try: | |
if not os.path.exists(MODEL_FILE_PATH): | |
return | |
with open(MODEL_FILE_PATH, "r", encoding="utf8") as f: | |
models = json.loads(f.read()) | |
for name, data in models.items(): | |
if name in self._agents: | |
continue | |
base_model = data["base_model"] | |
system_prompt = data["system_prompt"] | |
create_resource_cost = data.get("create_resource_cost", 0) | |
invoke_resource_cost = data.get("invoke_resource_cost", 0) | |
create_expense_cost = data.get("create_expense_cost", 0) | |
invoke_expense_cost = data.get("invoke_expense_cost", 0) | |
output_expense_cost = data.get("output_expense_cost", 0) | |
model_type = self._get_agent_type(base_model) | |
manager_class = self._agent_types.get(model_type) | |
if manager_class: | |
# Create the agent with the appropriate manager class | |
self._agents[name] = self.create_agent_class( | |
name, | |
base_model, | |
system_prompt, | |
description=data.get("description", ""), | |
create_resource_cost=create_resource_cost, | |
invoke_resource_cost=invoke_resource_cost, | |
create_expense_cost=create_expense_cost, | |
invoke_expense_cost=invoke_expense_cost, | |
output_expense_cost=output_expense_cost, | |
**data.get("additional_params", {}) | |
) | |
self._agents[name] = manager_class( | |
name, | |
base_model, | |
system_prompt, | |
create_resource_cost, | |
invoke_resource_cost, | |
create_expense_cost, | |
invoke_expense_cost, | |
output_expense_cost | |
) | |
except Exception as e: | |
output_assistant_response(f"Error loading agents: {e}") | |