CultriX committed on
Commit
2cb717c
·
verified ·
1 Parent(s): f5ce511

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +155 -104
run.py CHANGED
@@ -38,42 +38,40 @@ append_answer_lock = threading.Lock()
38
 
39
 
40
  class StreamingHandler(logging.Handler):
41
- """Custom logging handler that captures agent logs and sends them to callbacks."""
42
  def __init__(self):
43
  super().__init__()
44
  self.callbacks = []
 
45
 
46
  def add_callback(self, callback):
47
  self.callbacks.append(callback)
48
 
49
  def emit(self, record):
50
  msg = self.format(record)
51
- # Check if the message is actually different or non-empty after stripping
52
- # to avoid sending redundant empty strings, though `highlight_text` in app.py handles empty.
53
- if msg.strip():
54
- for callback in self.callbacks:
55
- callback(msg + '\n') # Add newline to ensure distinct lines are processed by app.py's splitter
56
 
57
 
58
- class StreamingCapture(StringIO):
59
- """Captures stdout/stderr and sends content to callbacks in real-time."""
60
  def __init__(self):
61
- super().__init__()
62
  self.callbacks = []
63
 
64
  def add_callback(self, callback):
65
  self.callbacks.append(callback)
66
 
67
- def write(self, s):
68
- # Pass the raw string 's' directly to callbacks immediately
69
- if s: # Only send if there's actual content
70
  for callback in self.callbacks:
71
- callback(s)
72
- super().write(s) # Still write to the underlying StringIO buffer
73
-
74
-
75
  def flush(self):
76
- super().flush()
77
 
78
 
79
  def create_agent(
@@ -92,10 +90,7 @@ def create_agent(
92
 
93
  if hf_token:
94
  print("[DEBUG] Logging into HuggingFace")
95
- try:
96
- login(hf_token)
97
- except Exception as e:
98
- print(f"[ERROR] Failed to log into HuggingFace: {e}")
99
 
100
  model_params = {
101
  "model_id": model_id,
@@ -106,19 +101,10 @@ def create_agent(
106
  if model_id == "gpt-4o-mini":
107
  model_params["reasoning_effort"] = "high"
108
 
109
- # Determine which API key to use based on the model_id
110
- if "openai" in model_id.lower() and openai_api_key:
111
- print("[DEBUG] Using OpenAI API key for OpenAI model")
112
- model_params["api_key"] = openai_api_key
113
- elif custom_api_endpoint and custom_api_key:
114
  print("[DEBUG] Using custom API endpoint:", custom_api_endpoint)
115
  model_params["base_url"] = custom_api_endpoint
116
  model_params["api_key"] = custom_api_key
117
- elif api_endpoint and openai_api_key: # Fallback to default OpenAI if custom not specified
118
- print("[DEBUG] Using default API endpoint:", api_endpoint)
119
- model_params["base_url"] = api_endpoint
120
- model_params["api_key"] = openai_api_key
121
- # It's important that if an API key is missing for the chosen model, it fails here or upstream.
122
 
123
  model = LiteLLMModel(**model_params)
124
  print("[DEBUG] Model initialized")
@@ -133,36 +119,23 @@ def create_agent(
133
  "headers": {"User-Agent": user_agent},
134
  "timeout": 300,
135
  },
136
- "serpapi_key": serpapi_key, # This will be used by ArchiveSearchTool if SerpAPI is enabled
137
  }
138
 
139
  os.makedirs(f"./{browser_config['downloads_folder']}", exist_ok=True)
140
  browser = SimpleTextBrowser(**browser_config)
141
  print("[DEBUG] Browser initialized")
142
 
143
- search_tool = None
144
  if search_provider == "searxng":
145
- print("[DEBUG] Using DuckDuckGoSearchTool (acting as a generic web search) for SearxNG context.")
146
  search_tool = DuckDuckGoSearchTool()
147
  if custom_search_url:
148
- # Note: As mentioned before, DuckDuckGoSearchTool doesn't natively use a custom base_url
149
- # for a completely different search engine like SearxNG. This line will likely have no effect.
150
- # For true SearxNG integration, you'd need a custom tool or a modified DuckDuckGoSearchTool
151
- # that knows how to query SearxNG instances.
152
- print(f"[WARNING] DuckDuckGoSearchTool does not directly support 'custom_search_url' for SearxNG. Consider a dedicated SearxNG tool.")
153
- # search_tool.base_url = custom_search_url # This line is often not effective for DDCSTool
154
- elif search_provider == "serper":
155
- print("[DEBUG] Using DuckDuckGoSearchTool (acting as a generic web search) for Serper context.")
156
- search_tool = DuckDuckGoSearchTool() # You would need a separate SerperTool for direct Serper API calls.
157
- if search_api_key:
158
- print("[DEBUG] Serper API Key provided. Ensure your search tool (if custom) uses it.")
159
- # If you had a dedicated SerperTool, you'd pass search_api_key to it.
160
- # e.g., search_tool = SerperTool(api_key=search_api_key)
161
  else:
162
- print("[DEBUG] No specific search provider selected, or provider not directly supported. Defaulting to DuckDuckGoSearchTool.")
163
  search_tool = DuckDuckGoSearchTool()
164
 
165
-
166
  WEB_TOOLS = [
167
  search_tool,
168
  VisitTool(browser),
@@ -170,15 +143,15 @@ def create_agent(
170
  PageDownTool(browser),
171
  FinderTool(browser),
172
  FindNextTool(browser),
173
- ArchiveSearchTool(browser), # This tool specifically uses serpapi_key from browser_config
174
  TextInspectorTool(model, text_limit),
175
  ]
176
 
177
  text_webbrowser_agent = ToolCallingAgent(
178
  model=model,
179
- tools=[tool for tool in WEB_TOOLS if tool is not None], # Filter out None if search_tool was not set
180
  max_steps=20,
181
- verbosity_level=3, # Keep this high for detailed output
182
  planning_interval=4,
183
  name="search_agent",
184
  description="A team member that will search the internet to answer your question.",
@@ -193,7 +166,7 @@ Additionally, if after some searching you find out that you need more informatio
193
  model=model,
194
  tools=[visualizer, TextInspectorTool(model, text_limit)],
195
  max_steps=16,
196
- verbosity_level=3, # Keep this high for detailed output
197
  additional_authorized_imports=AUTHORIZED_IMPORTS,
198
  planning_interval=4,
199
  managed_agents=[text_webbrowser_agent],
@@ -215,49 +188,35 @@ def run_agent_with_streaming(agent, question, stream_callback=None):
215
  root_logger = logging.getLogger()
216
  smolagents_logger = logging.getLogger('smolagents')
217
 
218
- # Store original handlers and levels
219
- original_root_handlers = root_logger.handlers[:]
220
- original_smolagents_handlers = smolagents_logger.handlers[:]
221
- original_root_level = root_logger.level
222
- original_smolagents_level = smolagents_logger.level
223
-
224
- # Store original stdout/stderr
225
- original_stdout = sys.stdout
226
- original_stderr = sys.stderr
227
 
228
- stdout_capture = StreamingCapture()
229
- stderr_capture = StreamingCapture()
230
-
231
- if stream_callback:
232
- stdout_capture.add_callback(stream_callback)
233
- stderr_capture.add_callback(stream_callback)
234
-
235
  try:
236
  # Configure logging to capture everything
237
- # Set logging levels very low to capture all verbose output
238
  root_logger.setLevel(logging.DEBUG)
239
- for handler in root_logger.handlers: # Remove existing handlers to avoid duplicate output
240
- root_logger.removeHandler(handler)
241
  root_logger.addHandler(log_handler)
242
-
243
  smolagents_logger.setLevel(logging.DEBUG)
244
- for handler in smolagents_logger.handlers: # Remove existing handlers
245
- smolagents_logger.removeHandler(handler)
246
  smolagents_logger.addHandler(log_handler)
247
 
248
- # Redirect stdout/stderr
249
- sys.stdout = stdout_capture
250
- sys.stderr = stderr_capture
251
-
252
- if stream_callback:
253
- stream_callback(f"[STARTING] Running agent with question: {question}\n")
254
-
255
- answer = agent.run(question)
256
 
257
  if stream_callback:
258
- stream_callback(f"[COMPLETED] {answer}\n")
 
 
 
 
 
 
 
259
 
260
- return answer
 
 
 
261
 
262
  except Exception as e:
263
  error_msg = f"[ERROR] Exception occurred: {str(e)}\n"
@@ -266,18 +225,111 @@ def run_agent_with_streaming(agent, question, stream_callback=None):
266
  raise
267
  finally:
268
  # Restore original logging configuration
269
- root_logger.handlers = original_root_handlers
270
- root_logger.setLevel(original_root_level)
271
- smolagents_logger.handlers = original_smolagents_handlers
272
- smolagents_logger.setLevel(original_smolagents_level)
273
-
274
- # Restore original stdout/stderr
275
- sys.stdout = original_stdout
276
- sys.stderr = original_stderr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- # Ensure any remaining buffered output is flushed (especially important for stdout/stderr)
279
- stdout_capture.flush()
280
- stderr_capture.flush()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
 
283
  def main():
@@ -290,20 +342,19 @@ def main():
290
  parser.add_argument("--model-id", type=str, default="gpt-4o-mini")
291
  parser.add_argument("--hf-token", type=str, default=os.getenv("HF_TOKEN"))
292
  parser.add_argument("--serpapi-key", type=str, default=os.getenv("SERPAPI_API_KEY"))
293
- parser.add_argument("--openai-api-key", type=str, default=os.getenv("OPENAI_API_KEY")) # Added
294
- parser.add_argument("--api-endpoint", type=str, default=os.getenv("API_ENDPOINT", "https://api.openai.com/v1")) # Added
295
  parser.add_argument("--custom-api-endpoint", type=str, default=None)
296
  parser.add_argument("--custom-api-key", type=str, default=None)
297
- parser.add_argument("--search-provider", type=str, default="searxng") # Changed default to searxng for consistency
298
  parser.add_argument("--search-api-key", type=str, default=None)
299
- parser.add_argument("--custom-search-url", type=str, default="https://search.endorisk.nl/search") # Changed default for consistency
300
  args = parser.parse_args()
301
 
302
  print("[DEBUG] CLI arguments parsed:", args)
303
 
304
  if args.gradio:
305
- print("Please run `app.py` directly to launch the Gradio interface.")
306
- return
 
307
  else:
308
  # CLI mode
309
  if not args.question:
@@ -313,9 +364,9 @@ def main():
313
  agent = create_agent(
314
  model_id=args.model_id,
315
  hf_token=args.hf_token,
316
- openai_api_key=args.openai_api_key,
317
  serpapi_key=args.serpapi_key,
318
- api_endpoint=args.api_endpoint,
319
  custom_api_endpoint=args.custom_api_endpoint,
320
  custom_api_key=args.custom_api_key,
321
  search_provider=args.search_provider,
 
38
 
39
 
40
  class StreamingHandler(logging.Handler):
41
+ """Custom logging handler that captures agent logs"""
42
  def __init__(self):
43
  super().__init__()
44
  self.callbacks = []
45
+ self.buffer = []
46
 
47
  def add_callback(self, callback):
48
  self.callbacks.append(callback)
49
 
50
  def emit(self, record):
51
  msg = self.format(record)
52
+ self.buffer.append(msg + '\n')
53
+ for callback in self.callbacks:
54
+ callback(msg + '\n')
 
 
55
 
56
 
57
+ class StreamingCapture:
58
+ """Captures stdout/stderr and yields content in real-time"""
59
  def __init__(self):
60
+ self.content = []
61
  self.callbacks = []
62
 
63
  def add_callback(self, callback):
64
  self.callbacks.append(callback)
65
 
66
+ def write(self, text):
67
+ if text.strip():
68
+ self.content.append(text)
69
  for callback in self.callbacks:
70
+ callback(text)
71
+ return len(text)
72
+
 
73
  def flush(self):
74
+ pass
75
 
76
 
77
  def create_agent(
 
90
 
91
  if hf_token:
92
  print("[DEBUG] Logging into HuggingFace")
93
+ login(hf_token)
 
 
 
94
 
95
  model_params = {
96
  "model_id": model_id,
 
101
  if model_id == "gpt-4o-mini":
102
  model_params["reasoning_effort"] = "high"
103
 
104
+ if custom_api_endpoint and custom_api_key:
 
 
 
 
105
  print("[DEBUG] Using custom API endpoint:", custom_api_endpoint)
106
  model_params["base_url"] = custom_api_endpoint
107
  model_params["api_key"] = custom_api_key
 
 
 
 
 
108
 
109
  model = LiteLLMModel(**model_params)
110
  print("[DEBUG] Model initialized")
 
119
  "headers": {"User-Agent": user_agent},
120
  "timeout": 300,
121
  },
122
+ "serpapi_key": serpapi_key,
123
  }
124
 
125
  os.makedirs(f"./{browser_config['downloads_folder']}", exist_ok=True)
126
  browser = SimpleTextBrowser(**browser_config)
127
  print("[DEBUG] Browser initialized")
128
 
129
+ # Correct tool selection
130
  if search_provider == "searxng":
131
+ print("[DEBUG] Using SearxNG-compatible DuckDuckGoSearchTool with base_url override")
132
  search_tool = DuckDuckGoSearchTool()
133
  if custom_search_url:
134
+ search_tool.base_url = custom_search_url # Override default DuckDuckGo URL (only if supported)
 
 
 
 
 
 
 
 
 
 
 
 
135
  else:
136
+ print("[DEBUG] Using default DuckDuckGoSearchTool for Serper/standard search")
137
  search_tool = DuckDuckGoSearchTool()
138
 
 
139
  WEB_TOOLS = [
140
  search_tool,
141
  VisitTool(browser),
 
143
  PageDownTool(browser),
144
  FinderTool(browser),
145
  FindNextTool(browser),
146
+ ArchiveSearchTool(browser),
147
  TextInspectorTool(model, text_limit),
148
  ]
149
 
150
  text_webbrowser_agent = ToolCallingAgent(
151
  model=model,
152
+ tools=WEB_TOOLS,
153
  max_steps=20,
154
+ verbosity_level=3,
155
  planning_interval=4,
156
  name="search_agent",
157
  description="A team member that will search the internet to answer your question.",
 
166
  model=model,
167
  tools=[visualizer, TextInspectorTool(model, text_limit)],
168
  max_steps=16,
169
+ verbosity_level=3,
170
  additional_authorized_imports=AUTHORIZED_IMPORTS,
171
  planning_interval=4,
172
  managed_agents=[text_webbrowser_agent],
 
188
  root_logger = logging.getLogger()
189
  smolagents_logger = logging.getLogger('smolagents')
190
 
191
+ # Store original handlers
192
+ original_handlers = root_logger.handlers[:]
193
+ original_level = root_logger.level
 
 
 
 
 
 
194
 
 
 
 
 
 
 
 
195
  try:
196
  # Configure logging to capture everything
 
197
  root_logger.setLevel(logging.DEBUG)
 
 
198
  root_logger.addHandler(log_handler)
 
199
  smolagents_logger.setLevel(logging.DEBUG)
 
 
200
  smolagents_logger.addHandler(log_handler)
201
 
202
+ # Also capture stdout/stderr
203
+ stdout_capture = StreamingCapture()
204
+ stderr_capture = StreamingCapture()
 
 
 
 
 
205
 
206
  if stream_callback:
207
+ stdout_capture.add_callback(stream_callback)
208
+ stderr_capture.add_callback(stream_callback)
209
+
210
+ with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
211
+ if stream_callback:
212
+ stream_callback(f"[STARTING] Running agent with question: {question}\n")
213
+
214
+ answer = agent.run(question)
215
 
216
+ if stream_callback:
217
+ stream_callback(f"[COMPLETED] Final answer: {answer}\n")
218
+
219
+ return answer
220
 
221
  except Exception as e:
222
  error_msg = f"[ERROR] Exception occurred: {str(e)}\n"
 
225
  raise
226
  finally:
227
  # Restore original logging configuration
228
+ root_logger.handlers = original_handlers
229
+ root_logger.setLevel(original_level)
230
+ smolagents_logger.removeHandler(log_handler)
231
+
232
+
233
+ def create_gradio_interface():
234
+ """Create Gradio interface with streaming support"""
235
+ import gradio as gr
236
+ import time
237
+ import threading
238
+
239
+ def process_question(question, model_id, hf_token, serpapi_key, custom_api_endpoint,
240
+ custom_api_key, search_provider, search_api_key, custom_search_url):
241
+
242
+ # Create agent
243
+ agent = create_agent(
244
+ model_id=model_id,
245
+ hf_token=hf_token,
246
+ openai_api_key=None,
247
+ serpapi_key=serpapi_key,
248
+ api_endpoint=None,
249
+ custom_api_endpoint=custom_api_endpoint,
250
+ custom_api_key=custom_api_key,
251
+ search_provider=search_provider,
252
+ search_api_key=search_api_key,
253
+ custom_search_url=custom_search_url,
254
+ )
255
+
256
+ # Shared state for streaming
257
+ output_buffer = []
258
+ is_complete = False
259
+
260
+ def stream_callback(text):
261
+ output_buffer.append(text)
262
+
263
+ def run_agent_async():
264
+ nonlocal is_complete
265
+ try:
266
+ answer = run_agent_with_streaming(agent, question, stream_callback)
267
+ output_buffer.append(f"\n\n**FINAL ANSWER:** {answer}")
268
+ except Exception as e:
269
+ output_buffer.append(f"\n\n**ERROR:** {str(e)}")
270
+ finally:
271
+ is_complete = True
272
 
273
+ # Start agent in background thread
274
+ agent_thread = threading.Thread(target=run_agent_async)
275
+ agent_thread.start()
276
+
277
+ # Generator that yields updates
278
+ last_length = 0
279
+ while not is_complete or agent_thread.is_alive():
280
+ current_output = "".join(output_buffer)
281
+ if len(current_output) > last_length:
282
+ yield current_output
283
+ last_length = len(current_output)
284
+ time.sleep(0.1) # Small delay to prevent excessive updates
285
+
286
+ # Final yield to ensure everything is captured
287
+ final_output = "".join(output_buffer)
288
+ if len(final_output) > last_length:
289
+ yield final_output
290
+
291
+ # Create Gradio interface
292
+ with gr.Blocks(title="Streaming Agent Chat") as demo:
293
+ gr.Markdown("# Streaming Agent Chat Interface")
294
+
295
+ with gr.Row():
296
+ with gr.Column():
297
+ question_input = gr.Textbox(label="Question", placeholder="Enter your question here...")
298
+ model_id_input = gr.Textbox(label="Model ID", value="gpt-4o-mini")
299
+ hf_token_input = gr.Textbox(label="HuggingFace Token", type="password")
300
+ serpapi_key_input = gr.Textbox(label="SerpAPI Key", type="password")
301
+ custom_api_endpoint_input = gr.Textbox(label="Custom API Endpoint")
302
+ custom_api_key_input = gr.Textbox(label="Custom API Key", type="password")
303
+ search_provider_input = gr.Dropdown(
304
+ choices=["serper", "searxng"],
305
+ value="serper",
306
+ label="Search Provider"
307
+ )
308
+ search_api_key_input = gr.Textbox(label="Search API Key", type="password")
309
+ custom_search_url_input = gr.Textbox(label="Custom Search URL")
310
+
311
+ submit_btn = gr.Button("Submit", variant="primary")
312
+
313
+ with gr.Column():
314
+ output = gr.Textbox(
315
+ label="Agent Output (Streaming)",
316
+ lines=30,
317
+ max_lines=50,
318
+ interactive=False
319
+ )
320
+
321
+ submit_btn.click(
322
+ fn=process_question,
323
+ inputs=[
324
+ question_input, model_id_input, hf_token_input, serpapi_key_input,
325
+ custom_api_endpoint_input, custom_api_key_input, search_provider_input,
326
+ search_api_key_input, custom_search_url_input
327
+ ],
328
+ outputs=output,
329
+ show_progress=True
330
+ )
331
+
332
+ return demo
333
 
334
 
335
  def main():
 
342
  parser.add_argument("--model-id", type=str, default="gpt-4o-mini")
343
  parser.add_argument("--hf-token", type=str, default=os.getenv("HF_TOKEN"))
344
  parser.add_argument("--serpapi-key", type=str, default=os.getenv("SERPAPI_API_KEY"))
 
 
345
  parser.add_argument("--custom-api-endpoint", type=str, default=None)
346
  parser.add_argument("--custom-api-key", type=str, default=None)
347
+ parser.add_argument("--search-provider", type=str, default="serper")
348
  parser.add_argument("--search-api-key", type=str, default=None)
349
+ parser.add_argument("--custom-search-url", type=str, default=None)
350
  args = parser.parse_args()
351
 
352
  print("[DEBUG] CLI arguments parsed:", args)
353
 
354
  if args.gradio:
355
+ # Launch Gradio interface
356
+ demo = create_gradio_interface()
357
+ demo.launch(share=True)
358
  else:
359
  # CLI mode
360
  if not args.question:
 
364
  agent = create_agent(
365
  model_id=args.model_id,
366
  hf_token=args.hf_token,
367
+ openai_api_key=None, # Fix: was openai_api_token
368
  serpapi_key=args.serpapi_key,
369
+ api_endpoint=None, # Fix: was api_endpoint
370
  custom_api_endpoint=args.custom_api_endpoint,
371
  custom_api_key=args.custom_api_key,
372
  search_provider=args.search_provider,