magnustragardh committed on
Commit
b3d6cbb
·
1 Parent(s): bafc8df

Add some rate limiting, and image support.

Browse files
Files changed (1) hide show
  1. app.py +52 -10
app.py CHANGED
@@ -2,14 +2,17 @@ import asyncio
2
  import os
3
  from pathlib import Path
4
  import gradio as gr
 
5
  import requests
6
  import pandas as pd
 
7
  from llama_index.llms.gemini import Gemini
8
- from llama_index.core.agent.workflow import ReActAgent
9
  from llama_index.core.tools import FunctionTool
10
  from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
11
  from llama_index.tools.wikipedia import WikipediaToolSpec
12
  from dotenv import load_dotenv
 
13
 
14
  try:
15
  import mlflow
@@ -28,7 +31,7 @@ GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
28
 
29
  # --- Basic Agent Definition ---
30
  class BasicAgent:
31
- def __init__(self):
32
  search_tool = FunctionTool.from_defaults(DuckDuckGoSearchToolSpec().duckduckgo_full_search)
33
  wikipedia_load_tool = FunctionTool.from_defaults(WikipediaToolSpec().load_data)
34
  wikipedia_search_tool = FunctionTool.from_defaults(WikipediaToolSpec().search_data)
@@ -39,10 +42,17 @@ class BasicAgent:
39
  # Modify the react prompt.
40
  self._agent.update_prompts({"react_header": SYSTEM_PROMPT})
41
  print("BasicAgent initialized.")
42
-
43
- async def __call__(self, question: str) -> str:
44
- print(f"Agent received question (first 50 chars): {question[:50]}...")
45
- agent_output = await self._agent.run(user_msg=question)
 
 
 
 
 
 
 
46
  print(f"Agent returning answer: {agent_output}")
47
  response_parts = str(agent_output).split('FINAL ANSWER: ')
48
  if len(response_parts) > 1:
@@ -65,14 +75,46 @@ def fetch_questions(api_url: str = DEFAULT_API_URL):
65
  return questions_data
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  async def answer_question(agent, item, answers_payload, results_log):
69
  task_id = item.get("task_id")
70
- question_text = item.get("question")
71
- if not task_id or question_text is None:
72
- print(f"Skipping item with missing task_id or question: {item}")
 
 
 
 
73
  return
74
  try:
75
- submitted_answer = await agent(question_text)
76
 
77
  # Avoid hitting the Google rate limits.
78
  await asyncio.sleep(60)
 
2
  import os
3
  from pathlib import Path
4
  import gradio as gr
5
+ import mimetypes
6
  import requests
7
  import pandas as pd
8
+ from llama_index.core.llms import ChatMessage, TextBlock, ImageBlock, AudioBlock
9
  from llama_index.llms.gemini import Gemini
10
+ from llama_index.core.agent.workflow import ReActAgent, AgentOutput
11
  from llama_index.core.tools import FunctionTool
12
  from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
13
  from llama_index.tools.wikipedia import WikipediaToolSpec
14
  from dotenv import load_dotenv
15
+ from pydantic import ValidationError
16
 
17
  try:
18
  import mlflow
 
31
 
32
  # --- Basic Agent Definition ---
33
  class BasicAgent:
34
+ def __init__(self, max_calls_per_minute=15):
35
  search_tool = FunctionTool.from_defaults(DuckDuckGoSearchToolSpec().duckduckgo_full_search)
36
  wikipedia_load_tool = FunctionTool.from_defaults(WikipediaToolSpec().load_data)
37
  wikipedia_search_tool = FunctionTool.from_defaults(WikipediaToolSpec().search_data)
 
42
  # Modify the react prompt.
43
  self._agent.update_prompts({"react_header": SYSTEM_PROMPT})
44
  print("BasicAgent initialized.")
45
+ self._min_call_interval = 1/max_calls_per_minute
46
+
47
+ async def __call__(self, question: ChatMessage) -> str:
48
+ question.blocks[0].text
49
+ print(f"Agent received question (first 50 chars): {question.blocks[0].text[:50]}...")
50
+ # Here, we need to rate limit
51
+ handler = self._agent.run(user_msg=question)
52
+ async for event in handler.stream_events():
53
+ if isinstance(event, AgentOutput):
54
+ await asyncio.sleep(self._min_call_interval)
55
+ agent_output = await handler
56
  print(f"Agent returning answer: {agent_output}")
57
  response_parts = str(agent_output).split('FINAL ANSWER: ')
58
  if len(response_parts) > 1:
 
75
  return questions_data
76
 
77
 
78
def get_media_type(filename: str):
    """Return the top-level MIME category guessed from a file name.

    The guess is based on the file name only (via ``mimetypes.guess_type``);
    the file contents are never inspected.

    Args:
        filename: Name or path of the file, e.g. ``"chart.png"``.

    Returns:
        The part of the guessed MIME type before the slash (for example
        ``"image"`` for ``image/png``), or ``None`` when ``filename`` is
        empty/``None`` or its type cannot be guessed.
    """
    # mimetypes.guess_type raises TypeError on None, so treat a missing name
    # the same way as an unrecognised one.
    if not filename:
        return None
    media_type, _encoding = mimetypes.guess_type(filename)
    if media_type is None:
        return None
    return media_type.partition('/')[0]
82
+
83
+
84
def get_media_content(item, timeout=30):
    """Download the attachment of a question item and wrap it in a content block.

    Args:
        item: Question dictionary. ``file_name`` names the attachment and
            ``task_id`` identifies the file to download from
            ``{DEFAULT_API_URL}/files/{task_id}``.
        timeout: Seconds to wait for the download before giving up, so that a
            stalled server cannot block the caller indefinitely.

    Returns:
        An ``ImageBlock`` holding the downloaded bytes when the attachment is
        an image. ``None`` when the item has no attachment or no task id, the
        download fails, or the media type is not supported.
    """
    file_name = item.get('file_name')
    task_id = item.get('task_id')
    # Without a task id the URL would end in "/files/None"; skip the request.
    if not file_name or not task_id:
        return None

    try:
        file_response = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=timeout)
    except requests.RequestException as exc:
        # Treat transport failures like HTTP failures below: no media block.
        print(f"Could not download attachment for task {task_id}: {exc}")
        return None

    # A requests.Response is truthy only when its status code is below 400.
    if not file_response:
        return None

    if get_media_type(file_name) == 'image':
        return ImageBlock(image=file_response.content)

    # Audio is currently not supported; once it is, an
    # AudioBlock(audio=file_response.content) could be returned here.
    return None
94
+
95
+
96
def create_question_message(item):
    """Build the user chat message for a single question item.

    The message always starts with a ``TextBlock`` carrying the item's
    ``question`` value. If ``get_media_content`` returns a block for the
    item, that block is appended after the text.

    Args:
        item: Question dictionary; read via ``item.get("question")`` and
            passed on to ``get_media_content``.

    Returns:
        A ``ChatMessage`` with role ``"user"`` containing the blocks above.
    """
    blocks = [TextBlock(text=item.get("question"))]
    attachment = get_media_content(item)
    if attachment is not None:
        blocks.append(attachment)
    return ChatMessage(role="user", blocks=blocks)
104
+
105
+
106
  async def answer_question(agent, item, answers_payload, results_log):
107
  task_id = item.get("task_id")
108
+ try:
109
+ question_message = create_question_message(item)
110
+ except ValidationError:
111
+ print(f"Skipping item for which the question could not be processed: {item}")
112
+ return
113
+ if not task_id:
114
+ print(f"Skipping item with missing task_id: {item}")
115
  return
116
  try:
117
+ submitted_answer = await agent.run(question_message)
118
 
119
  # Avoid hitting the Google rate limits.
120
  await asyncio.sleep(60)