helloparthshah committed
Commit 7c06e97
1 Parent(s): a3a158e

Updated to use pro and fixed retry and edit

Files changed (3):
  1. main.py +53 -14
  2. mainV2.py +0 -61
  3. src/CEO.py +0 -134
main.py CHANGED
@@ -1,24 +1,63 @@
 from google.genai import types
-from src.CEO import GeminiManager
+from src.manager import GeminiManager
 from src.tool_loader import ToolLoader
-
+import gradio as gr
+import time
 
 if __name__ == "__main__":
     # Define the tool metadata for orchestration.
     # Load the tools using the ToolLoader class.
     tool_loader = ToolLoader()
 
-    model_manager = GeminiManager(toolsLoader=tool_loader, gemini_model="gemini-2.0-flash")
+    model_manager = GeminiManager(toolsLoader=tool_loader, gemini_model="gemini-2.5-pro-preview-03-25")
 
-    task_prompt = (
-        "What is the peak price of trump coin in the last 30 days? "
-        "Please provide the price in USD. "
-    )
+    def user_message(msg: str, history: list) -> tuple[str, list]:
+        """Adds user message to chat history"""
+        history.append(gr.ChatMessage(role="user", content=msg))
+        return "", history
 
-    # Request a CEO response with the prompt.
-    # user_prompt_content = types.Content(
-    #     role='user',
-    #     parts=[types.Part.from_text(text=task_prompt)],
-    # )
-    # response = model_manager.request([user_prompt_content])
-    response = model_manager.start_conversation()
+    def handle_undo(history, undo_data: gr.UndoData):
+        return history[:undo_data.index], history[undo_data.index]['content']
+
+    def handle_retry(history, retry_data: gr.RetryData):
+        new_history = history[:retry_data.index+1]
+        # yield new_history, gr.update(interactive=False,)
+        yield from model_manager.run(new_history)
+
+    def handle_edit(history, edit_data: gr.EditData):
+        new_history = history[:edit_data.index+1]
+        new_history[-1]['content'] = edit_data.value
+        # yield new_history, gr.update(interactive=False,)
+        yield from model_manager.run(new_history)
+
+    with gr.Blocks(fill_width=True, fill_height=True) as demo:
+        gr.Markdown("# Hashiru AI")
+
+        chatbot = gr.Chatbot(
+            avatar_images=("HASHIRU_2.png", "HASHIRU.png"),
+            type="messages",
+            show_copy_button=True,
+            editable="user",
+            scale=1
+        )
+        input_box = gr.Textbox(label="Chat Message", scale=0, interactive=True, submit_btn=True)
+
+        chatbot.undo(handle_undo, chatbot, [chatbot, input_box])
+        chatbot.retry(handle_retry, chatbot, [chatbot, input_box])
+        chatbot.edit(handle_edit, chatbot, [chatbot, input_box])
+
+        input_box.submit(
+            user_message,  # Add user message to chat
+            inputs=[input_box, chatbot],
+            outputs=[input_box, chatbot],
+            queue=False,
+        ).then(
+            model_manager.run,  # Generate and stream response
+            inputs=chatbot,
+            outputs=[chatbot, input_box],
+            queue=True,
+            show_progress="full",
+            trigger_mode="always_last"
+        )
+
+    demo.launch(share=True)
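
The retry and edit fixes come down to where the history list is sliced: Gradio's `gr.RetryData.index` and `gr.EditData.index` point at the user message being retried or edited, so slicing with `index + 1` keeps that turn in the truncated history before the model is re-run (the deleted mainV2.py below sliced at `index`, and its edit handler never re-ran the model). A minimal sketch of the slicing semantics, using plain dicts as stand-ins for `gr.ChatMessage` objects and a hypothetical history:

# Stand-in history; real entries are gr.ChatMessage objects.
history = [
    {"role": "user", "content": "hi"},                 # index 0
    {"role": "assistant", "content": "hello"},         # index 1
    {"role": "user", "content": "peak price?"},        # index 2  <- retry/edit target
    {"role": "assistant", "content": "stale answer"},  # index 3
]

retry_index = 2                          # what gr.RetryData.index would carry
new_history = history[:retry_index + 1]  # keeps the user turn, drops the stale answer
assert new_history[-1]["role"] == "user"

new_history[-1]["content"] = "peak price in USD?"  # what handle_edit does with gr.EditData.value
# Both handlers then regenerate: yield from model_manager.run(new_history)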
mainV2.py DELETED
@@ -1,61 +0,0 @@
1
- from google.genai import types
2
- from src.manager import GeminiManager
3
- from src.tool_loader import ToolLoader
4
- import gradio as gr
5
- import time
6
-
7
- if __name__ == "__main__":
8
- # Define the tool metadata for orchestration.
9
- # Load the tools using the ToolLoader class.
10
- tool_loader = ToolLoader()
11
-
12
- model_manager = GeminiManager(toolsLoader=tool_loader, gemini_model="gemini-2.0-flash")
13
-
14
- def user_message(msg: str, history: list) -> tuple[str, list]:
15
- """Adds user message to chat history"""
16
- history.append(gr.ChatMessage(role="user", content=msg))
17
- return "", history
18
-
19
- def handle_undo(history, undo_data: gr.UndoData):
20
- return history[:undo_data.index], history[undo_data.index]['content']
21
-
22
- def handle_retry(history, retry_data: gr.RetryData):
23
- new_history = history[:retry_data.index]
24
- yield from model_manager.run(new_history)
25
-
26
- def handle_edit(history, edit_data: gr.EditData):
27
- new_history = history[:edit_data.index]
28
- new_history[-1]['content'] = edit_data.value
29
- return new_history
30
-
31
- with gr.Blocks(fill_width=True, fill_height=True) as demo:
32
- gr.Markdown("# Hashiru AI")
33
-
34
- chatbot = gr.Chatbot(
35
- avatar_images=("HASHIRU_2.png", "HASHIRU.png"),
36
- type="messages",
37
- show_copy_button=True,
38
- editable="user",
39
- scale=1
40
- )
41
- input_box = gr.Textbox(max_lines=5, label="Chat Message", scale=0)
42
-
43
- chatbot.undo(handle_undo, chatbot, [chatbot, input_box])
44
- chatbot.retry(handle_retry, chatbot, chatbot)
45
- chatbot.edit(handle_edit, chatbot, chatbot)
46
-
47
- input_box.submit(
48
- user_message, # Add user message to chat
49
- inputs=[input_box, chatbot],
50
- outputs=[input_box, chatbot],
51
- queue=False,
52
- ).then(
53
- model_manager.run, # Generate and stream response
54
- inputs=chatbot,
55
- outputs=[chatbot, input_box],
56
- queue=True,
57
- show_progress="full",
58
- trigger_mode="always_last"
59
- )
60
-
61
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/CEO.py DELETED
@@ -1,134 +0,0 @@
-from google import genai
-from google.genai import types
-import os
-from dotenv import load_dotenv
-import sys
-from src.tool_loader import ToolLoader
-from src.utils.suppress_outputs import suppress_output
-import logging
-
-from src.utils.streamlit_interface import get_user_message, output_assistant_response
-
-logger = logging.getLogger(__name__)
-handler = logging.StreamHandler(sys.stdout)
-handler.setLevel(logging.INFO)
-logger.addHandler(handler)
-
-class GeminiManager:
-    def __init__(self, toolsLoader: ToolLoader, system_prompt_file="./models/system3.prompt", gemini_model="gemini-2.5-pro-exp-03-25"):
-        load_dotenv()
-        self.API_KEY = os.getenv("GEMINI_KEY")
-        self.client = genai.Client(api_key=self.API_KEY)
-        self.toolsLoader: ToolLoader = toolsLoader
-        self.toolsLoader.load_tools()
-        self.model_name = gemini_model
-        with open(system_prompt_file, 'r', encoding="utf8") as f:
-            self.system_prompt = f.read()
-        self.messages = []
-
-    def generate_response(self, messages):
-        return self.client.models.generate_content(
-            #model='gemini-2.5-pro-preview-03-25',
-            model=self.model_name,
-            #model='gemini-2.5-pro-exp-03-25',
-            #model='gemini-2.0-flash',
-            contents=messages,
-            config=types.GenerateContentConfig(
-                system_instruction=self.system_prompt,
-                temperature=0.2,
-                tools=self.toolsLoader.getTools(),
-            ),
-        )
-
-    def handle_tool_calls(self, response):
-        parts = []
-        for function_call in response.function_calls:
-            toolResponse = None
-            logger.info(f"Function Name: {function_call.name}, Arguments: {function_call.args}")
-            try:
-                toolResponse = self.toolsLoader.runTool(function_call.name, function_call.args)
-            except Exception as e:
-                logger.warning(f"Error running tool: {e}")
-                toolResponse = {
-                    "status": "error",
-                    "message": f"Tool {function_call.name} failed to run.",
-                    "output": str(e),
-                }
-            logger.debug(f"Tool Response: {toolResponse}")
-            tool_content = types.Part.from_function_response(
-                name=function_call.name,
-                response={"result": toolResponse})
-            try:
-                self.toolsLoader.load_tools()
-            except Exception as e:
-                logger.info(f"Error loading tools: {e}. Deleting the tool.")
-                # delete the created tool
-                self.toolsLoader.delete_tool(toolResponse['output']['tool_name'], toolResponse['output']['tool_file_path'])
-                tool_content = types.Part.from_function_response(
-                    name=function_call.name,
-                    response={"result": f"{function_call.name} with {function_call.args} doesn't follow the required format, please read the other tool implementations for reference." + str(e)})
-            parts.append(tool_content)
-        return types.Content(
-            role='model' if self.model_name == "gemini-2.5-pro-exp-03-25" else 'tool',
-            parts=parts
-        )
-
-    def run(self, messages):
-        try:
-            response = suppress_output(self.generate_response)(messages)
-        except Exception as e:
-            logger.debug(f"Error generating response: {e}")
-            shouldRetry = get_user_message("An error occurred. Do you want to retry? (y/n): ")
-            if shouldRetry and shouldRetry.lower() == "y":
-                return self.run(messages)
-            else:
-                output_assistant_response("Ending the conversation.")
-                return messages
-
-        logger.debug(f"Response: {response}")
-
-        if (not response.text and not response.function_calls):
-            output_assistant_response("No response from the model.")
-
-        # Attach the llm response to the messages
-        if response.text is not None:
-            output_assistant_response("CEO: " + response.text)
-            # print("CEO:", response.text)
-            assistant_content = types.Content(
-                role='model' if self.model_name == "gemini-2.5-pro-exp-03-25" else 'assistant',
-                parts=[types.Part.from_text(text=response.text)],
-            )
-            messages.append(assistant_content)
-
-        # Attach the function call response to the messages
-        if response.candidates[0].content and response.candidates[0].content.parts:
-            messages.append(response.candidates[0].content)
-
-        # Invoke the function calls if any and attach the response to the messages
-        if response.function_calls:
-            messages.append(self.handle_tool_calls(response))
-            shouldContinue = get_user_message("Should I continue? (y/n): ")
-            if shouldContinue.lower() == "y":
-                return self.run(messages)
-            else:
-                output_assistant_response("Ending the conversation.")
-                return messages
-        else:
-            logger.debug("No tool calls found in the response.")
-        # Start the loop again
-        return self.start_conversation(messages)
-
-    def start_conversation(self, messages=[]):
-        question = get_user_message("User: ")
-        # question = input("User: ")
-        if question and ("exit" in question.lower() or "quit" in question.lower()):
-            output_assistant_response("Ending the conversation.")
-            return messages
-        user_content = types.Content(
-            role='user',
-            parts=[types.Part.from_text(text=question)],
-        )
-        messages.append(user_content)
-
-        # Start the conversation loop
-        return self.run(messages)
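
src/manager.py itself is not part of this commit, but the new main.py pins down the interface it must expose: `GeminiManager.run` has to be a generator over the messages-format history, because the handlers delegate with `yield from model_manager.run(new_history)` and the `.then()` step streams each yield into `[chatbot, input_box]`. A minimal sketch of a compatible shape; everything below (names, chunking, update logic) is an assumption, since the real module is not shown:

# Hypothetical sketch of what main.py assumes from src.manager.GeminiManager.
# This is an illustration of the required generator contract, not the real code.
import gradio as gr

class GeminiManager:
    def __init__(self, toolsLoader=None, gemini_model="gemini-2.5-pro-preview-03-25"):
        self.toolsLoader = toolsLoader
        self.model_name = gemini_model

    def run(self, history: list):
        # Generator: each yield is (chatbot_history, textbox_update), matching
        # outputs=[chatbot, input_box] in main.py. Yielding repeatedly is what
        # makes the assistant reply stream into the Chatbot.
        history.append(gr.ChatMessage(role="assistant", content=""))
        for chunk in ("Thinking", " about it", "... done."):  # stand-in for model tokens
            history[-1].content += chunk
            yield history, gr.update(interactive=False)  # lock input while generating
        yield history, gr.update(interactive=True)       # unlock when finished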