Spaces:
Running
Running
File size: 30,934 Bytes
31add3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 |
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
import chainlit as cl
import sys
import app
from langchain_openai import ChatOpenAI
import langchain_openai
import inspect
from langchain_core.messages import AIMessageChunk
# Functions/classes to be tested (from app.py)
# We need to ensure app.py can be imported or we define these structures similarly
# For now, assuming app.on_message and app.InsightFlowState are accessible for mocking/patching
# from app import on_message, InsightFlowState # Ideal import
# Let's use the state definition from state.py for consistency
from insight_state import InsightFlowState
from utils.persona import PersonaFactory, PersonaReasoning
from langchain_core.messages import AIMessageChunk
# Dotted module path used as the prefix for every unittest.mock.patch target in
# this file (e.g. f"{APP_PY_PATH}.cl", f"{APP_PY_PATH}.invoke_langgraph").
# patch() resolves the target by importing this name, so it must match the
# module name app.py is imported under (this file does `import app`).
# NOTE(review): may need adjusting if Chainlit loads app.py under another name.
APP_PY_PATH = "app"
# Use the actual InsightFlowState for type consistency if needed, but primarily for session storage key
from insight_state import InsightFlowState
# Helper function to set up mock_cl behavior for these tests
def setup_mock_cl_session_get(mock_cl, direct_mode_value, quick_mode_value=False, initial_state_dict=None,
                              session_overrides=None):
    """Make ``mock_cl.user_session.get`` return canned values for these tests.

    Args:
        mock_cl: Mocked ``chainlit`` module (the ``mock_cl`` conftest fixture).
        direct_mode_value: Returned for ``get("direct_mode")``.
        quick_mode_value: Returned for ``get("quick_mode", ...)``.
        initial_state_dict: Returned for ``get("insight_flow_state")``. When
            ``None``, a fresh default state is used: empty query, no personas,
            ``current_step_name="awaiting_query"`` and ``panel_type="research"``.
        session_overrides: Optional mapping of session keys to values. These
            take precedence over every built-in key. This replaces the previous
            ``'session_data' in globals()`` lookup, which could never succeed
            because ``session_data`` only exists as a local in another test.

    Returns:
        The exact state object served for ``"insight_flow_state"``, so callers
        can inspect or mutate it.
    """
    if initial_state_dict is not None:
        current_state = initial_state_dict
    else:
        current_state = {
            "query": "", "selected_personas": [], "persona_responses": {},
            "synthesized_response": None, "visualization_code": None,
            "visualization_image_url": None, "current_step_name": "awaiting_query",
            "error_message": None, "panel_type": "research",
        }
    overrides = dict(session_overrides) if session_overrides else {}
    fixed_values = {
        "direct_mode": direct_mode_value,
        "quick_mode": quick_mode_value,
        "insight_flow_state": current_state,
    }

    def side_effect_func(key, default=None):
        # Mirrors the real signature user_session.get(key, default=None).
        if key in overrides:
            return overrides[key]
        if key in fixed_values:
            return fixed_values[key]
        if key == "persona_factory":
            return MagicMock(spec=PersonaFactory)
        if key == "id":
            return "default_thread_id"
        # Unknown keys keep the historical behaviour: a fresh MagicMock is
        # returned and `default` is ignored. Callers may rely on a truthy,
        # attribute-rich value here.
        return MagicMock()

    mock_cl.user_session.get.side_effect = side_effect_func
    return current_state
@pytest.mark.asyncio
async def test_on_message_direct_mode_on_calls_direct_llm(mock_cl):  # mock_cl comes from conftest
    """When direct_mode is enabled, on_message routes the query to the direct LLM and never to LangGraph."""
    incoming = MagicMock(spec=cl.Message)
    incoming.content = "Test query for direct mode"

    setup_mock_cl_session_get(mock_cl, direct_mode_value=True)
    # Any cl.Message(...) the handler builds must expose an awaitable send().
    mock_cl.Message.return_value.send = AsyncMock()

    direct_patch = patch(f'{APP_PY_PATH}.invoke_direct_llm', new_callable=AsyncMock)
    graph_patch = patch(f'{APP_PY_PATH}.invoke_langgraph', new_callable=AsyncMock)
    with patch(f'{APP_PY_PATH}.cl', new=mock_cl), direct_patch as direct_llm, graph_patch as graph_runner:
        # Imported inside the patched context, matching the other tests in this file.
        from app import on_message
        await on_message(incoming)

    direct_llm.assert_called_once_with(incoming.content)
    graph_runner.assert_not_called()
    mock_cl.user_session.get.assert_any_call("direct_mode")
@pytest.mark.asyncio
async def test_on_message_direct_mode_off_calls_langgraph(mock_cl):  # Use mock_cl from conftest
    """Test that on_message calls LangGraph if direct_mode is False and progress messages are handled.

    The real ``invoke_langgraph`` runs. Only ``invoke_direct_llm``, the graph's
    ``ainvoke`` and ``InsightFlowCallbackHandler`` are patched, so the test
    checks:

    * the progress-message lifecycle;
    * the callback handler wiring;
    * the state and config passed to the graph.
    """
    mock_incoming_message = MagicMock(spec=cl.Message)
    mock_incoming_message.content = "Test query for graph mode"
    # The returned state dict is not needed by this test.
    setup_mock_cl_session_get(mock_cl, direct_mode_value=False)

    # --- Mock cl.Message for progress updates --- #
    # The progress message is created with content=""; that call is routed to a
    # dedicated instance so its send/stream_token/update can be asserted.
    mock_progress_message_instance = AsyncMock(spec=cl.Message)
    mock_progress_message_instance.send = AsyncMock()
    mock_progress_message_instance.stream_token = AsyncMock()
    mock_progress_message_instance.update = AsyncMock()
    # Every other cl.Message(...) construction gets this generic instance.
    default_mock_message_instance = AsyncMock(spec=cl.Message, send=AsyncMock())

    def message_side_effect(*args, **kwargs):
        if kwargs.get("content") == "":  # This is how progress_msg is initialized
            return mock_progress_message_instance
        return default_mock_message_instance

    mock_cl.Message.side_effect = message_side_effect
    # Clear calls recorded by earlier users of mock_cl; reset_mock() keeps side_effect.
    mock_cl.Message.reset_mock()

    # Progress message handed to the callback handler's constructor (captured below).
    passed_progress_msg_to_callback = None

    class MockInsightFlowCallbackHandler(app.InsightFlowCallbackHandler):
        """Subclass of the real handler that only records its constructor argument.

        super().__init__ is deliberately not called: instances are used only
        for isinstance checks and argument capture.
        """

        def __init__(self, progress_message: cl.Message):
            nonlocal passed_progress_msg_to_callback
            passed_progress_msg_to_callback = progress_message
            self.progress_message = progress_message

    def mock_ifch_constructor_side_effect(progress_message):
        # Called by the MagicMock that replaces app.InsightFlowCallbackHandler.
        return MockInsightFlowCallbackHandler(progress_message=progress_message)

    with patch(f'{APP_PY_PATH}.cl', new=mock_cl):  # Patches app.cl
        with patch(f'{APP_PY_PATH}.invoke_direct_llm', new_callable=AsyncMock) as mock_direct_call:
            # Let the real invoke_langgraph run; stub only the graph it invokes.
            with patch(f'{APP_PY_PATH}.insight_flow_graph.ainvoke', new_callable=AsyncMock) as mock_graph_ainvoke:
                with patch(f'{APP_PY_PATH}.InsightFlowCallbackHandler') as mock_actual_ifch_class_constructor:
                    mock_actual_ifch_class_constructor.side_effect = mock_ifch_constructor_side_effect
                    from app import on_message  # Import under patch context
                    await on_message(mock_incoming_message)

    # --- Assertions --- #
    mock_direct_call.assert_not_called()

    # 1. The progress message is created with content="", sent, and seeded.
    mock_cl.Message.assert_any_call(content="")
    mock_progress_message_instance.send.assert_called_once()
    mock_progress_message_instance.stream_token.assert_any_call("⏳ Initializing InsightFlow process...")

    # 2. The progress message is stored in the user session.
    mock_cl.user_session.set.assert_any_call("progress_msg", mock_progress_message_instance)

    # 3. The callback handler is constructed with that progress message.
    mock_actual_ifch_class_constructor.assert_called_once_with(progress_message=mock_progress_message_instance)
    assert passed_progress_msg_to_callback is mock_progress_message_instance

    # 4. The graph is invoked once with the query in state and a single callback.
    mock_graph_ainvoke.assert_called_once()
    actual_state_to_graph = mock_graph_ainvoke.call_args[0][0]
    assert actual_state_to_graph["query"] == mock_incoming_message.content
    actual_config_to_graph = mock_graph_ainvoke.call_args.kwargs['config']
    assert len(actual_config_to_graph["callbacks"]) == 1
    assert isinstance(actual_config_to_graph["callbacks"][0], MockInsightFlowCallbackHandler)

    # 5. The final progress update is streamed and pushed.
    mock_progress_message_instance.stream_token.assert_any_call("\n✨ InsightFlow processing complete!")
    mock_progress_message_instance.update.assert_called_once()

    # 6. The expected session keys were read.
    mock_cl.user_session.get.assert_any_call("direct_mode")
    mock_cl.user_session.get.assert_any_call("insight_flow_state")
@pytest.mark.asyncio
async def test_invoke_langgraph_orchestrates_persona_execution(mock_cl):
    """
    Test that invoke_langgraph correctly processes a query through the graph,
    leading to execute_persona_tasks calling personas with their designated LLMs.

    Mocks are placed at these points:

    * the ``astream`` method of replacement persona LLMs, which are swapped
      into ``app.PERSONA_LLM_MAP``;
    * ``app.llm_synthesizer.ainvoke``;
    * the two visualization helpers.
    """
    query = "Tell me about black holes."
    initial_selected_personas = ["analytical", "scientific"]
    initial_insight_state = InsightFlowState(
        query="",  # Query will be set by invoke_langgraph
        selected_personas=initial_selected_personas,
        persona_responses={},
        synthesized_response=None,
        visualization_code=None,
        visualization_image_url=None,
        current_step_name="awaiting_query",
        error_message=None,
        panel_type="research"
    )

    # --- Session: the persona factory and thread id come from a local dict ---
    mock_persona_factory_instance = MagicMock(spec=PersonaFactory)

    def actual_create_persona_side_effect(persona_id, llm_instance):
        # Return REAL PersonaReasoning objects that wrap whichever LLM the app
        # passes in (our mocked replacements), so persona execution genuinely
        # calls that LLM. Unknown persona ids yield None.
        if persona_id == "analytical":
            return PersonaReasoning(persona_id="analytical", name="Analytical (Test)", system_prompt="Analytical System Prompt", llm=llm_instance)
        elif persona_id == "scientific":
            return PersonaReasoning(persona_id="scientific", name="Scientific (Test)", system_prompt="Scientific System Prompt", llm=llm_instance)
        return None

    mock_persona_factory_instance.create_persona = MagicMock(side_effect=actual_create_persona_side_effect)
    session_data = {
        "persona_factory": mock_persona_factory_instance,
        "id": "test_thread_id",
    }
    mock_cl.user_session.get.side_effect = lambda key, default=None: session_data.get(key, default)

    # --- cl.Message: the progress message (content="") vs. all other messages ---
    mock_progress_msg_invoke_lg = AsyncMock(spec=cl.Message)
    mock_progress_msg_invoke_lg.send = AsyncMock()
    mock_progress_msg_invoke_lg.stream_token = AsyncMock()
    mock_progress_msg_invoke_lg.update = AsyncMock()
    # Messages created elsewhere (e.g. present_results); only send() is asserted.
    mock_other_msg_instance = AsyncMock(spec=cl.Message)
    mock_other_msg_instance.send = AsyncMock()

    def message_constructor_side_effect(*args, **kwargs):
        if kwargs.get("content") == "":  # Progress message from invoke_langgraph
            return mock_progress_msg_invoke_lg
        return mock_other_msg_instance

    mock_cl.Message.side_effect = message_constructor_side_effect
    # Clear calls recorded by earlier users of mock_cl; reset_mock() keeps side_effect.
    mock_cl.Message.reset_mock()

    # --- Scripted LLM streams ---
    async def mock_llm_astream_analytical(*args, **kwargs):
        yield AIMessageChunk(content="Analytical perspective ")
        yield AIMessageChunk(content="on black holes.")

    async def mock_llm_astream_scientific(*args, **kwargs):
        yield AIMessageChunk(content="Scientific perspective ")
        yield AIMessageChunk(content="on black holes.")

    class MockAsyncIterator:
        """Async-iterable wrapper around an async generator function, an async generator, or a plain iterable.

        Single-use: the underlying generator is created once, in __init__.
        """

        def __init__(self, items_or_generator_func, *args, **kwargs):
            if inspect.isasyncgenfunction(items_or_generator_func):
                # Call the generator function (with any extra args) to obtain a generator.
                self.async_generator = items_or_generator_func(*args, **kwargs)
            elif inspect.isasyncgen(items_or_generator_func):
                # Already an async generator object; use it directly.
                self.async_generator = items_or_generator_func
            else:
                # Treat the argument as a plain iterable and wrap it.
                async def _gen():
                    for item in items_or_generator_func:
                        yield item
                self.async_generator = _gen()

        def __aiter__(self):
            # The wrapped async generator is itself the iterator.
            return self.async_generator

        async def __anext__(self):
            # Unused: iteration goes through the generator returned by __aiter__.
            raise NotImplementedError

    with patch.object(app, 'cl', new=mock_cl) as mock_cl_in_app:
        from app import invoke_langgraph, initialize_configurations
        # Ensure configurations are initialized so llm_synthesizer is not None.
        initialize_configurations()
        # Read the map via the module attribute *after* initialization, in case
        # initialize_configurations() rebinds it instead of mutating it in place.
        test_persona_llm_map = app.PERSONA_LLM_MAP.copy()

        # NOTE(review): AsyncMock makes `llm.astream(...)` return a coroutine that
        # resolves to the iterator. This only works if the persona code awaits
        # astream(...) before iterating. A plain `async for ... in llm.astream(...)`
        # would instead need MagicMock(side_effect=<async generator function>).
        # Confirm against PersonaReasoning.
        mock_analytical_llm_replacement = MagicMock(spec=ChatOpenAI)
        analytical_astream_method_mock = AsyncMock(
            return_value=MockAsyncIterator(mock_llm_astream_analytical)
        )
        mock_analytical_llm_replacement.astream = analytical_astream_method_mock

        mock_scientific_llm_replacement = MagicMock(spec=ChatOpenAI)
        scientific_astream_method_mock = AsyncMock(
            return_value=MockAsyncIterator(mock_llm_astream_scientific)
        )
        mock_scientific_llm_replacement.astream = scientific_astream_method_mock

        test_persona_llm_map["analytical"] = mock_analytical_llm_replacement
        test_persona_llm_map["scientific"] = mock_scientific_llm_replacement

        # Replace ainvoke on the live synthesizer instance; it is restored in `finally`.
        mock_synthesizer_response = MagicMock()
        mock_synthesizer_response.content = "Synthesized view of black holes."
        original_synthesizer_ainvoke = app.llm_synthesizer.ainvoke
        app.llm_synthesizer.ainvoke = AsyncMock(return_value=mock_synthesizer_response)

        mock_dalle_url = "http://fake_dalle_url.com/image.png"
        mock_mermaid_code = "graph TD; A-->B;"
        try:
            # patch.object restores only PERSONA_LLM_MAP on exit. patch.dict on
            # the module __dict__ would instead clear and rebuild the whole
            # module namespace.
            with patch.object(app, "PERSONA_LLM_MAP", test_persona_llm_map):
                with patch(f'{APP_PY_PATH}.generate_dalle_image', AsyncMock(return_value=mock_dalle_url)) as mock_gen_dalle:
                    with patch(f'{APP_PY_PATH}.generate_mermaid_code', AsyncMock(return_value=mock_mermaid_code)) as mock_gen_mermaid:
                        final_state = await invoke_langgraph(query, initial_insight_state)
        finally:
            app.llm_synthesizer.ainvoke = original_synthesizer_ainvoke

    # --- Session lookups ---
    mock_cl_in_app.user_session.get.assert_any_call("persona_factory")
    mock_cl_in_app.user_session.get.assert_any_call("id", "default_thread_id")

    # --- Personas were built around the replacement LLMs, which were streamed ---
    mock_persona_factory_instance.create_persona.assert_any_call("analytical", mock_analytical_llm_replacement)
    mock_persona_factory_instance.create_persona.assert_any_call("scientific", mock_scientific_llm_replacement)
    analytical_astream_method_mock.assert_called_once()
    scientific_astream_method_mock.assert_called_once()

    # --- Final state content ---
    assert final_state["query"] == query
    assert "analytical" in final_state["persona_responses"]
    assert final_state["persona_responses"]["analytical"] == "Analytical perspective on black holes."
    assert "scientific" in final_state["persona_responses"]
    assert final_state["persona_responses"]["scientific"] == "Scientific perspective on black holes."
    assert final_state["current_step_name"] == "results_presented"
    # `or ""` turns a missing or None response into a clean assertion failure instead of a TypeError.
    assert "Synthesized view" in (final_state.get("synthesized_response") or "")
    assert final_state.get("visualization_code") == mock_mermaid_code
    assert final_state.get("visualization_image_url") == mock_dalle_url

    # --- Progress message from invoke_langgraph ---
    mock_cl.Message.assert_any_call(content="")
    mock_progress_msg_invoke_lg.send.assert_called_once()
    mock_progress_msg_invoke_lg.stream_token.assert_any_call("⏳ Initializing InsightFlow process...")
    mock_progress_msg_invoke_lg.stream_token.assert_any_call("\n✨ InsightFlow processing complete!")
    mock_progress_msg_invoke_lg.update.assert_called_once()

    # present_results (or another step) sent at least one non-progress message.
    assert mock_other_msg_instance.send.called

    # --- Visualization helpers were each called once ---
    mock_gen_dalle.assert_called_once()
    mock_gen_mermaid.assert_called_once()
@pytest.mark.asyncio
async def test_on_message_quick_mode_on_overrides_personas(mock_cl):
    """With quick_mode on, on_message swaps selected_personas for QUICK_MODE_PERSONAS before invoking the graph.

    The progress-message lifecycle and callback-handler wiring are verified as well.
    """
    incoming = MagicMock(spec=cl.Message)
    incoming.content = "Test query for quick mode on"

    starting_state = {
        "query": "",
        "selected_personas": ["creative", "historical"],
        "persona_responses": {},
        "synthesized_response": None,
        "visualization_code": None,
        "visualization_image_url": None,
        "current_step_name": "awaiting_query",
        "error_message": None,
        "panel_type": "research",
    }
    setup_mock_cl_session_get(
        mock_cl,
        direct_mode_value=False,
        quick_mode_value=True,
        initial_state_dict=starting_state,
    )

    # The progress message (built with content="") vs. every other cl.Message.
    progress_msg = AsyncMock(spec=cl.Message)
    for method_name in ("send", "stream_token", "update"):
        setattr(progress_msg, method_name, AsyncMock())
    other_msg = AsyncMock(spec=cl.Message, send=AsyncMock())

    def route_message(*args, **kwargs):
        return progress_msg if kwargs.get("content") == "" else other_msg

    mock_cl.Message.side_effect = route_message
    mock_cl.Message.reset_mock()  # forget earlier calls; side_effect survives

    quick_personas = ["test_quick1", "test_quick2"]
    captured = {}

    class MockCBHandlerQuickOnLocal(app.InsightFlowCallbackHandler):
        # Records the progress message it was built with; skips the real __init__.
        def __init__(self, progress_message: cl.Message):
            captured["progress_message"] = progress_message
            self.progress_message = progress_message

    def build_handler(progress_message):
        return MockCBHandlerQuickOnLocal(progress_message=progress_message)

    cl_patch = patch(f'{APP_PY_PATH}.cl', new=mock_cl)
    personas_patch = patch(f'{APP_PY_PATH}.QUICK_MODE_PERSONAS', new=quick_personas)
    graph_patch = patch(f'{APP_PY_PATH}.insight_flow_graph.ainvoke', new_callable=AsyncMock)
    handler_patch = patch(f'{APP_PY_PATH}.InsightFlowCallbackHandler', side_effect=build_handler)
    with cl_patch, personas_patch, graph_patch as graph_ainvoke, handler_patch as handler_ctor:
        from app import on_message
        await on_message(incoming)

    # Progress message lifecycle and handler wiring.
    mock_cl.Message.assert_any_call(content="")
    progress_msg.send.assert_called_once()
    progress_msg.stream_token.assert_any_call("⏳ Initializing InsightFlow process...")
    mock_cl.user_session.set.assert_any_call("progress_msg", progress_msg)
    handler_ctor.assert_called_once_with(progress_message=progress_msg)
    assert captured.get("progress_message") is progress_msg
    progress_msg.stream_token.assert_any_call("\n✨ InsightFlow processing complete!")
    progress_msg.update.assert_called_once()

    # The graph saw the quick-mode personas and the incoming query.
    graph_ainvoke.assert_called_once()
    state_sent = graph_ainvoke.call_args[0][0]
    assert state_sent["selected_personas"] == quick_personas
    assert state_sent["query"] == incoming.content

    for expected_args in (("direct_mode",), ("quick_mode", False), ("insight_flow_state",)):
        mock_cl.user_session.get.assert_any_call(*expected_args)
@pytest.mark.asyncio
async def test_on_message_quick_mode_off_uses_original_personas(mock_cl):
    """Test that on_message with quick_mode=False uses original selected_personas for invoke_langgraph.

    Also verifies two things:

    * the progress-message lifecycle (create with content="", send, stream,
      update, store in session);
    * that the patched InsightFlowCallbackHandler receives that message.
    """
    mock_incoming_message = MagicMock(spec=cl.Message)
    mock_incoming_message.content = "Test query for quick mode off"
    original_selected_personas = ["original1", "original2"]
    initial_state_dict_off = {
        "query": "", "selected_personas": original_selected_personas, "persona_responses": {},
        "synthesized_response": None, "visualization_code": None,
        "visualization_image_url": None, "current_step_name": "awaiting_query",
        "error_message": None, "panel_type": "research"
    }
    setup_mock_cl_session_get(mock_cl, direct_mode_value=False, quick_mode_value=False, initial_state_dict=initial_state_dict_off)

    # The progress message is created with content=""; route it to a dedicated mock.
    mock_progress_message_instance_off = AsyncMock(spec=cl.Message)
    mock_progress_message_instance_off.send = AsyncMock()
    mock_progress_message_instance_off.stream_token = AsyncMock()
    mock_progress_message_instance_off.update = AsyncMock()
    default_mock_other_message_instance_off = AsyncMock(spec=cl.Message, send=AsyncMock())

    def message_side_effect_quick_off(*args, **kwargs):
        if kwargs.get("content") == "":
            return mock_progress_message_instance_off
        return default_mock_other_message_instance_off

    mock_cl.Message.side_effect = message_side_effect_quick_off
    # Clear recorded calls from earlier use of mock_cl; reset_mock() keeps side_effect.
    mock_cl.Message.reset_mock()

    passed_progress_msg_to_callback_quick_off = None

    # A local handler class (unique name per test) that only captures its argument.
    class MockCBHandlerQuickOffLocal(app.InsightFlowCallbackHandler):
        def __init__(self, progress_message: cl.Message):
            nonlocal passed_progress_msg_to_callback_quick_off
            passed_progress_msg_to_callback_quick_off = progress_message
            # super().__init__ is intentionally skipped; only argument capture matters.
            self.progress_message = progress_message

    def mock_ifch_constructor_side_effect_quick_off(progress_message):
        return MockCBHandlerQuickOffLocal(progress_message=progress_message)

    with patch(f'{APP_PY_PATH}.cl', new=mock_cl):
        with patch(f'{APP_PY_PATH}.insight_flow_graph.ainvoke', new_callable=AsyncMock) as mock_graph_ainvoke:
            # Replace the handler class with a MagicMock whose side_effect builds the local mock.
            with patch(f'{APP_PY_PATH}.InsightFlowCallbackHandler') as mock_ifch_constructor_quick_off:
                mock_ifch_constructor_quick_off.side_effect = mock_ifch_constructor_side_effect_quick_off
                from app import on_message
                await on_message(mock_incoming_message)

    # Progress message assertions.
    mock_cl.Message.assert_any_call(content="")
    mock_progress_message_instance_off.send.assert_called_once()
    mock_progress_message_instance_off.stream_token.assert_any_call("⏳ Initializing InsightFlow process...")
    mock_cl.user_session.set.assert_any_call("progress_msg", mock_progress_message_instance_off)
    # The patched class constructor received the progress message.
    mock_ifch_constructor_quick_off.assert_called_once_with(progress_message=mock_progress_message_instance_off)
    assert passed_progress_msg_to_callback_quick_off is mock_progress_message_instance_off
    mock_progress_message_instance_off.stream_token.assert_any_call("\n✨ InsightFlow processing complete!")
    mock_progress_message_instance_off.update.assert_called_once()

    # With quick mode off, the graph sees the personas already in session state.
    mock_graph_ainvoke.assert_called_once()
    called_state = mock_graph_ainvoke.call_args[0][0]
    assert called_state["selected_personas"] == original_selected_personas
    assert called_state["query"] == mock_incoming_message.content
    mock_cl.user_session.get.assert_any_call("direct_mode")
    mock_cl.user_session.get.assert_any_call("quick_mode", False)
    mock_cl.user_session.get.assert_any_call("insight_flow_state")