Harshil Patel commited on
Commit
e5d4b0f
·
1 Parent(s): 8875451

Add gemini agent

Browse files
src/agent_manager.py CHANGED
@@ -5,7 +5,11 @@ import json
5
  import ollama
6
  from src.utils.singleton import singleton
7
  from src.utils.streamlit_interface import output_assistant_response
8
-
 
 
 
 
9
  class Agent(ABC):
10
 
11
  def __init__(self, agent_name: str, base_model: str, system_prompt: str, creation_cost: str, invoke_cost: str):
@@ -59,13 +63,42 @@ class OllamaAgent(Agent):
59
  def delete_agent(self):
60
  ollama.delete(self.agent_name)
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  @singleton
63
  class AgentManager():
64
 
65
  def __init__(self):
66
  self._agents = {}
67
  self._agent_types ={
68
- "ollama": OllamaAgent
 
69
  }
70
 
71
  self._load_agents()
@@ -185,6 +218,8 @@ class AgentManager():
185
  return "ollama"
186
  elif base_model == "mistral":
187
  return "ollama"
 
 
188
  else:
189
  return "unknown"
190
 
 
5
  import ollama
6
  from src.utils.singleton import singleton
7
  from src.utils.streamlit_interface import output_assistant_response
8
+ from google import genai
9
+ from google.genai import types
10
+ from google.genai.types import *
11
+ import os
12
+ from dotenv import load_dotenv
13
  class Agent(ABC):
14
 
15
  def __init__(self, agent_name: str, base_model: str, system_prompt: str, creation_cost: str, invoke_cost: str):
 
63
  def delete_agent(self):
64
  ollama.delete(self.agent_name)
65
 
66
+ class GeminiAgent(Agent):
67
+ def __init__(self, agent_name: str, base_model: str, system_prompt: str, creation_cost: str, invoke_cost: str):
68
+ load_dotenv()
69
+ self.api_key = os.getenv("GEMINI_KEY")
70
+ if not self.api_key:
71
+ raise ValueError("Google API key is required for Gemini models. Set GOOGLE_API_KEY environment variable or pass api_key parameter.")
72
+
73
+ # Initialize the Gemini API
74
+ self.client = genai.Client(api_key=self.api_key)
75
+
76
+ # Call parent constructor after API setup
77
+ super().__init__(agent_name, base_model, system_prompt, creation_cost, invoke_cost)
78
+
79
+ def create_model(self):
80
+ self.messages = []
81
+
82
+ def ask_agent(self, prompt):
83
+ response = self.client.models.generate_content(
84
+ model=self.base_model,
85
+ contents=prompt,
86
+ config=types.GenerateContentConfig(
87
+ system_instruction=self.system_prompt,
88
+ )
89
+ )
90
+ return response.text
91
+
92
+ def delete_agent(self):
93
+ self.messages = []
94
  @singleton
95
  class AgentManager():
96
 
97
  def __init__(self):
98
  self._agents = {}
99
  self._agent_types ={
100
+ "ollama": OllamaAgent,
101
+ "gemini": GeminiAgent
102
  }
103
 
104
  self._load_agents()
 
218
  return "ollama"
219
  elif base_model == "mistral":
220
  return "ollama"
221
+ elif "gemini" in base_model:
222
+ return "gemini"
223
  else:
224
  return "unknown"
225
 
tools/agent_creater_tool.py CHANGED
@@ -18,7 +18,7 @@ class AgentCreator():
18
  },
19
  "base_model": {
20
  "type": "string",
21
- "description": "A base model from which the new agent mode is to be created. Available models are: llama3.2, mistral"
22
  },
23
  "system_prompt": {
24
  "type": "string",
@@ -43,6 +43,11 @@ class AgentCreator():
43
  "description": "7 Billion parameter model",
44
  "create_cost": 20,
45
  "invoke_cost": 50,
 
 
 
 
 
46
  }
47
  }
48
  }
 
18
  },
19
  "base_model": {
20
  "type": "string",
21
+ "description": "A base model from which the new agent mode is to be created. Available models are: llama3.2, mistral, gemini-2.0-flash"
22
  },
23
  "system_prompt": {
24
  "type": "string",
 
43
  "description": "7 Billion parameter model",
44
  "create_cost": 20,
45
  "invoke_cost": 50,
46
+ },
47
+ "gemini-2.0-flash": {
48
+ "description": "40 Billion parameter model",
49
+ "create_cost": 30,
50
+ "invoke_cost": 60,
51
  }
52
  }
53
  }