File size: 30,934 Bytes
31add3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, ANY
import chainlit as cl
import sys
import app
from langchain_openai import ChatOpenAI
import langchain_openai
import inspect
from langchain_core.messages import AIMessageChunk

# Functions/classes to be tested (from app.py)
# We need to ensure app.py can be imported or we define these structures similarly
# For now, assuming app.on_message and app.InsightFlowState are accessible for mocking/patching
# from app import on_message, InsightFlowState # Ideal import

# Let's use the state definition from state.py for consistency
from insight_state import InsightFlowState
from utils.persona import PersonaFactory, PersonaReasoning
from langchain_core.messages import AIMessageChunk

# Placeholder for the actual on_message if it were in a separate module
# For now, we'll patch where it's defined (globally if app.py is run directly)
# This path might need adjustment based on how Chainlit runs the app
APP_PY_PATH = "app"

# Use the actual InsightFlowState for type consistency if needed, but primarily for session storage key
from insight_state import InsightFlowState

# Helper function to set up mock_cl behavior for these tests
def setup_mock_cl_session_get(mock_cl, direct_mode_value, quick_mode_value=False, initial_state_dict=None):
    """Wire ``mock_cl.user_session.get`` with a key-aware side effect.

    Serves ``direct_mode`` / ``quick_mode`` / ``insight_flow_state`` from the
    supplied values, and returns the state dict that will back
    ``"insight_flow_state"`` (the caller-supplied dict, or a fresh default).
    """
    if initial_state_dict is not None:
        active_state = initial_state_dict
    else:
        active_state = {
            "query": "", "selected_personas": [], "persona_responses": {},
            "synthesized_response": None, "visualization_code": None,
            "visualization_image_url": None, "current_step_name": "awaiting_query",
            "error_message": None, "panel_type": "research"
        }

    def fake_session_get(key, default=None):  # mirrors the real cl.user_session.get signature
        if key == "direct_mode":
            return direct_mode_value
        if key == "quick_mode":
            return quick_mode_value
        if key == "insight_flow_state":
            return active_state
        if key == "persona_factory":
            # Prefer a module-global `session_data` mapping when one exists;
            # otherwise hand back a fresh factory mock.
            if 'session_data' in globals() and session_data:
                return session_data.get(key, MagicMock(spec=PersonaFactory))
            return MagicMock(spec=PersonaFactory)
        if key == "id":
            if 'session_data' in globals() and session_data:
                return session_data.get(key, "default_thread_id")
            return "default_thread_id"
        # Any other session key gets a generic mock.
        return MagicMock()

    mock_cl.user_session.get.side_effect = fake_session_get
    return active_state

@pytest.mark.asyncio
async def test_on_message_direct_mode_on_calls_direct_llm(mock_cl): # mock_cl fixture from conftest
    """With direct_mode enabled, on_message must route the query to the direct
    LLM path and never invoke the LangGraph pipeline."""
    incoming = MagicMock(spec=cl.Message)
    incoming.content = "Test query for direct mode"

    setup_mock_cl_session_get(mock_cl, direct_mode_value=True)
    # on_message may send a reply message; its send() must be awaitable.
    mock_cl.Message.return_value.send = AsyncMock()

    # Patch app.cl plus both dispatch targets, then run the handler.
    with patch(f'{APP_PY_PATH}.cl', new=mock_cl), \
         patch(f'{APP_PY_PATH}.invoke_direct_llm', new_callable=AsyncMock) as direct_llm_mock, \
         patch(f'{APP_PY_PATH}.invoke_langgraph', new_callable=AsyncMock) as langgraph_mock:
        from app import on_message  # Import under patch context
        await on_message(incoming)

    direct_llm_mock.assert_called_once_with(incoming.content)
    langgraph_mock.assert_not_called()
    mock_cl.user_session.get.assert_any_call("direct_mode")

@pytest.mark.asyncio
async def test_on_message_direct_mode_off_calls_langgraph(mock_cl): # Use mock_cl from conftest
    """Test that on_message calls LangGraph if direct_mode is False and progress messages are handled.

    Verifies, in order:
      1. the progress message is created via cl.Message(content="") and sent;
      2. it is stored in the user session under "progress_msg";
      3. InsightFlowCallbackHandler receives that same progress message;
      4. insight_flow_graph.ainvoke gets the query state and the callback;
      5. the final completion token is streamed and the message updated.
    The real invoke_langgraph runs; only its collaborators are mocked.
    """
    mock_incoming_message = MagicMock(spec=cl.Message)
    mock_incoming_message.content = "Test query for graph mode"

    expected_initial_state = setup_mock_cl_session_get(mock_cl, direct_mode_value=False)
    
    # --- Mock cl.Message for progress updates --- #
    # We need to capture the instance of the progress message to check its methods.
    mock_progress_message_instance = AsyncMock(spec=cl.Message) # Mock the instance
    mock_progress_message_instance.send = AsyncMock()
    mock_progress_message_instance.stream_token = AsyncMock()
    mock_progress_message_instance.update = AsyncMock()

    # Configure mock_cl.Message constructor to return our specific instance when content is ""
    # For other cl.Message calls (e.g. in present_results), it can return a default MagicMock
    default_mock_message_instance = AsyncMock(spec=cl.Message, send=AsyncMock())
    def message_side_effect(*args, **kwargs):
        if kwargs.get("content") == "": # This is how progress_msg is initialized
            return mock_progress_message_instance
        return default_mock_message_instance # For any other messages created
    mock_cl.Message.side_effect = message_side_effect
    mock_cl.Message.reset_mock() # Reset call count for the Message class itself from previous tests using mock_cl

    # To store the actual progress message instance passed to the callback handler
    passed_progress_msg_to_callback = None

    # --- Mock InsightFlowCallbackHandler --- #    
    # We need to capture the arguments passed to its constructor, especially the progress_message.
    # This class will be instantiated by the side_effect of our main mock.
    class MockInsightFlowCallbackHandler(app.InsightFlowCallbackHandler):
        def __init__(self, progress_message: cl.Message):
            nonlocal passed_progress_msg_to_callback
            passed_progress_msg_to_callback = progress_message
            # super().__init__(progress_message) # We don't need to call super for this mock.
                                            # The instance itself can be a simple object for type checking
                                            # and argument capture verification.
            self.progress_message = progress_message # Store it for potential inspection if needed

    # This function will be the side_effect for the MagicMock replacing app.InsightFlowCallbackHandler
    # It ensures our local MockInsightFlowCallbackHandler is created, capturing the args.
    def mock_ifch_constructor_side_effect(progress_message):
        return MockInsightFlowCallbackHandler(progress_message=progress_message)

    # Patch app.cl (used by on_message and invoke_langgraph) and other app internals
    with patch(f'{APP_PY_PATH}.cl', new=mock_cl): # Patches app.cl
        with patch(f'{APP_PY_PATH}.invoke_direct_llm', new_callable=AsyncMock) as mock_direct_call:
            # We want the real invoke_langgraph to run, so we patch what IT calls:
            with patch(f'{APP_PY_PATH}.insight_flow_graph.ainvoke', new_callable=AsyncMock) as mock_graph_ainvoke:
                # Patch the CallbackHandler class with a MagicMock.
                # Its side_effect will use our local MockInsightFlowCallbackHandler for instantiation.
                with patch(f'{APP_PY_PATH}.InsightFlowCallbackHandler') as mock_actual_ifch_class_constructor:
                    mock_actual_ifch_class_constructor.side_effect = mock_ifch_constructor_side_effect
                    from app import on_message # Import under patch context
                    await on_message(mock_incoming_message)

    # --- Assertions --- #
    mock_direct_call.assert_not_called()
    
    # 1. Initial progress message creation and sending
    #    cl.Message(content="") should have been called by invoke_langgraph
    mock_cl.Message.assert_any_call(content="")
    mock_progress_message_instance.send.assert_called_once()
    mock_progress_message_instance.stream_token.assert_any_call("⏳ Initializing InsightFlow process...")

    # 2. Progress message stored in user session
    mock_cl.user_session.set.assert_any_call("progress_msg", mock_progress_message_instance)

    # 3. InsightFlowCallbackHandler instantiation
    # Assert that the MagicMock representing the class constructor was called correctly.
    mock_actual_ifch_class_constructor.assert_called_once_with(progress_message=mock_progress_message_instance)
    assert passed_progress_msg_to_callback is mock_progress_message_instance # Ensure the correct msg object was captured

    # 4. Graph invocation with callback handler
    mock_graph_ainvoke.assert_called_once()
    # Check state passed to graph (first arg of ainvoke)
    actual_state_to_graph = mock_graph_ainvoke.call_args[0][0]
    assert actual_state_to_graph["query"] == mock_incoming_message.content
    # check config passed to graph (second arg of ainvoke, specifically callbacks)
    actual_config_to_graph = mock_graph_ainvoke.call_args.kwargs['config']
    assert len(actual_config_to_graph["callbacks"]) == 1
    assert isinstance(actual_config_to_graph["callbacks"][0], MockInsightFlowCallbackHandler)

    # 5. Final progress update
    mock_progress_message_instance.stream_token.assert_any_call("\n✨ InsightFlow processing complete!")
    mock_progress_message_instance.update.assert_called_once()
    
    # 6. Session `get` calls (original assertions)
    mock_cl.user_session.get.assert_any_call("direct_mode")
    mock_cl.user_session.get.assert_any_call("insight_flow_state") 

@pytest.mark.asyncio
async def test_invoke_langgraph_orchestrates_persona_execution(mock_cl):
    """
    Test that invoke_langgraph correctly processes a query through the graph,
    leading to execute_persona_tasks calling personas with their designated LLMs.
    Mocks are placed at the LLM's astream method.

    Strategy: real PersonaReasoning objects are created by the factory mock,
    but they receive replacement LLMs (installed into app.PERSONA_LLM_MAP via
    patch.dict) whose astream methods yield canned AIMessageChunk streams.
    app.llm_synthesizer.ainvoke is monkey-patched directly on the instance and
    restored in the finally block. DALL-E and Mermaid generation are mocked.
    """
    query = "Tell me about black holes."
    initial_selected_personas = ["analytical", "scientific"]
    initial_insight_state = InsightFlowState(
        query="", # Query will be set by invoke_langgraph
        selected_personas=initial_selected_personas,
        persona_responses={},
        synthesized_response=None,
        visualization_code=None,
        visualization_image_url=None,
        current_step_name="awaiting_query",
        error_message=None,
        panel_type="research"
    )

    # --- Mock cl.user_session.get specifically for this test ---
    mock_persona_factory_instance = MagicMock(spec=PersonaFactory) # From utils.persona
    
    # Mock the create_persona method on the factory instance
    # It should now return *actual* PersonaReasoning instances, initialized with the passed (mocked) LLMs.
    def actual_create_persona_side_effect(persona_id, llm_instance):
        # This side effect will mimic the real factory's behavior of creating real PersonaReasoning objects,
        # but using the llm_instance provided (which will be one of our fully mocked LLMs).
        # We need the system prompts for the real PersonaReasoning constructor.
        # For simplicity in this test, we can use dummy prompts or fetch from a minimal config.
        # Or, ensure the mock_persona_factory_instance has a minimal persona_configs attribute for this.
        # Let's assume PersonaReasoning can be created with the llm_instance directly for this test purpose
        # if we mock its internal config loading or provide a simplified one.

        # The key is that it returns a REAL PersonaReasoning that will call .astream on the llm_instance it received.
        if persona_id == "analytical":
            return PersonaReasoning(persona_id="analytical", name="Analytical (Test)", system_prompt="Analytical System Prompt", llm=llm_instance)
        elif persona_id == "scientific":
            return PersonaReasoning(persona_id="scientific", name="Scientific (Test)", system_prompt="Scientific System Prompt", llm=llm_instance)
        return None
    
    # Ensure mock_persona_factory_instance is a spec of PersonaFactory so it has create_persona
    mock_persona_factory_instance.create_persona = MagicMock(side_effect=actual_create_persona_side_effect)

    session_data = {
        "persona_factory": mock_persona_factory_instance,
        "id": "test_thread_id",
        # "insight_flow_state": initial_insight_state # Not needed for invoke_langgraph direct call
    }
    mock_cl.user_session.get.side_effect = lambda key, default=None: session_data.get(key, default)
    
    # --- Mock cl.Message for progress and results messages --- #
    # Mock for the progress message created in invoke_langgraph
    mock_progress_msg_invoke_lg = AsyncMock(spec=cl.Message)
    mock_progress_msg_invoke_lg.send = AsyncMock()
    mock_progress_msg_invoke_lg.stream_token = AsyncMock()
    mock_progress_msg_invoke_lg.update = AsyncMock()

    # Mock for messages created in present_results (or other general messages)
    mock_other_msg_instance = AsyncMock(spec=cl.Message)
    mock_other_msg_instance.send = AsyncMock()
    # Add other methods like .stream_token or .update if present_results uses them directly
    # For now, send is the primary one asserted for present_results in this test.

    def message_constructor_side_effect(*args, **kwargs):
        if kwargs.get("content") == "": # Progress message from invoke_langgraph
            return mock_progress_msg_invoke_lg
        # Here, you might add more conditions if present_results creates messages
        # with specific content that needs a different mock. For now, this default works.
        return mock_other_msg_instance

    mock_cl.Message.side_effect = message_constructor_side_effect
    # Ensure mock_cl.Message class itself can be checked for calls like assert_any_call(content="")
    # We also need to reset call counts if mock_cl is shared across tests and cl.Message was called before.
    mock_cl.Message.reset_mock() 

    # --- Patch LLMs used by PersonaReasoning via PERSONA_LLM_MAP in app.py ---
    # We need to patch the actual LLM instances in app.py that PersonaReasoning will use.
    # The PersonaReasoning objects themselves are created *during* the graph run.
    # So, we mock the .astream method on the LLMs defined in app.py's PERSONA_LLM_MAP

    # Define mock stream behavior
    async def mock_llm_astream_analytical(*args, **kwargs):
        yield AIMessageChunk(content="Analytical perspective ")
        yield AIMessageChunk(content="on black holes.")

    async def mock_llm_astream_scientific(*args, **kwargs):
        yield AIMessageChunk(content="Scientific perspective ")
        yield AIMessageChunk(content="on black holes.")

    # --- Custom Async Iterator for mocking --- 
    class MockAsyncIterator:
        # Wraps an async-generator function, an async-generator object, or a
        # plain list of items so it can serve as an AsyncMock return_value.
        def __init__(self, items_or_generator_func, *args, **kwargs):
            # If a generator func is passed, call it to get the generator
            if inspect.isasyncgenfunction(items_or_generator_func):
                self.async_generator = items_or_generator_func(*args, **kwargs)
            elif inspect.isasyncgen(items_or_generator_func): # if already a generator object
                self.async_generator = items_or_generator_func
            else: # assume it's a list of items to be wrapped
                async def _gen():
                    for item in items_or_generator_func:
                        yield item
                self.async_generator = _gen()

        def __aiter__(self):
            return self.async_generator # The async_generator itself is the iterator

        async def __anext__(self):
            # This is not strictly needed if __aiter__ returns a proper async generator
            # as the generator handles its own __anext__.
            # However, to be a complete async iterator, it could be defined.
            # For now, relying on the returned async_generator from __aiter__.
            raise NotImplementedError # Should not be called if __aiter__ returns a true async gen

    # --- Function wrappers to ensure AsyncMock side_effect returns the async generator object directly ---
    def analytical_astream_wrapper(*args, **kwargs):
        return mock_llm_astream_analytical(*args, **kwargs)
    
    def scientific_astream_wrapper(*args, **kwargs):
        return mock_llm_astream_scientific(*args, **kwargs)

    # Patching the astream method of the LLMs by replacing them in app.PERSONA_LLM_MAP
    with patch.object(app, 'cl', new=mock_cl) as mock_cl_in_app:
        from app import invoke_langgraph, PERSONA_LLM_MAP, llm_analytical, llm_scientific, llm_synthesizer, initialize_configurations # Import map and original LLMs for type reference
        
        # Ensure configurations are initialized so llm_synthesizer is not None
        initialize_configurations() 

        original_persona_llm_map = PERSONA_LLM_MAP.copy()
        test_persona_llm_map = PERSONA_LLM_MAP.copy()

        mock_analytical_llm_replacement = MagicMock(spec=ChatOpenAI)
        analytical_astream_method_mock = AsyncMock(
            return_value=MockAsyncIterator(mock_llm_astream_analytical)
        )
        mock_analytical_llm_replacement.astream = analytical_astream_method_mock

        mock_scientific_llm_replacement = MagicMock(spec=ChatOpenAI)
        scientific_astream_method_mock = AsyncMock(
            return_value=MockAsyncIterator(mock_llm_astream_scientific)
        )
        mock_scientific_llm_replacement.astream = scientific_astream_method_mock

        test_persona_llm_map["analytical"] = mock_analytical_llm_replacement
        test_persona_llm_map["scientific"] = mock_scientific_llm_replacement

        # Mock for llm_synthesizer.ainvoke
        mock_synthesizer_response = MagicMock()
        mock_synthesizer_response.content = "Synthesized view of black holes."
        
        # Directly mock the ainvoke method on the app.llm_synthesizer instance
        # This needs app.llm_synthesizer to be already initialized.
        original_synthesizer_ainvoke = app.llm_synthesizer.ainvoke # Store original for restoration
        app.llm_synthesizer.ainvoke = AsyncMock(return_value=mock_synthesizer_response)

        # Patch the PERSONA_LLM_MAP in the app module.
        # The ChatOpenAI.ainvoke class patch is removed as we are direct mocking instance.
        mock_dalle_url = "http://fake_dalle_url.com/image.png"
        mock_mermaid_code = "graph TD; A-->B;"

        try:
            with patch.dict(sys.modules['app'].__dict__, {"PERSONA_LLM_MAP": test_persona_llm_map}):
                # Class patch for ChatOpenAI.ainvoke is removed here
                with patch(f'{APP_PY_PATH}.generate_dalle_image', AsyncMock(return_value=mock_dalle_url)) as mock_gen_dalle:
                    with patch(f'{APP_PY_PATH}.generate_mermaid_code', AsyncMock(return_value=mock_mermaid_code)) as mock_gen_mermaid:
                        final_state = await invoke_langgraph(query, initial_insight_state)
        finally:
            # Restore the original ainvoke method
            app.llm_synthesizer.ainvoke = original_synthesizer_ainvoke

    # Assertions
    mock_cl_in_app.user_session.get.assert_any_call("persona_factory")
    mock_cl_in_app.user_session.get.assert_any_call("id", "default_thread_id")

    # Check that create_persona on the factory was called correctly
    # This was part of the original test logic, but the main check is that the LLMs get called.
    # We can verify calls to the mock_persona_factory_instance.create_persona if needed,
    # ensuring it was called with the correct persona_id and the *actual* LLM instance from app.PERSONA_LLM_MAP
    # For now, primary assertion is on the LLM astream calls.
    mock_persona_factory_instance.create_persona.assert_any_call("analytical", mock_analytical_llm_replacement)
    mock_persona_factory_instance.create_persona.assert_any_call("scientific", mock_scientific_llm_replacement)

    # Verify the astream method on our *replacement mock LLM instances* was called.
    analytical_astream_method_mock.assert_called_once()
    scientific_astream_method_mock.assert_called_once()

    # Check final state content
    assert final_state["query"] == query
    assert "analytical" in final_state["persona_responses"]
    assert final_state["persona_responses"]["analytical"] == "Analytical perspective on black holes."
    assert "scientific" in final_state["persona_responses"]
    assert final_state["persona_responses"]["scientific"] == "Scientific perspective on black holes."
    assert final_state["current_step_name"] == "results_presented"
    assert "Synthesized view" in final_state.get("synthesized_response", "") # Check it contains the key part
    assert final_state.get("visualization_code") == mock_mermaid_code # Verify mocked mermaid code
    assert final_state.get("visualization_image_url") == mock_dalle_url # Verify mocked DALL-E URL
    
    # --- Assertions for Progress Message from invoke_langgraph ---
    mock_cl.Message.assert_any_call(content="") # Initial progress message creation
    mock_progress_msg_invoke_lg.send.assert_called_once() # Sent initially
    mock_progress_msg_invoke_lg.stream_token.assert_any_call("⏳ Initializing InsightFlow process...")
    # Add assertion for session set if invoke_langgraph sets progress_msg in session for this test's path
    # mock_cl.user_session.set.assert_any_call("progress_msg", mock_progress_msg_invoke_lg)
    mock_progress_msg_invoke_lg.stream_token.assert_any_call("\n✨ InsightFlow processing complete!")
    mock_progress_msg_invoke_lg.update.assert_called_once()

    # Check that present_results sent a message (using the mock_other_msg_instance)
    # This assertion assumes present_results creates one message. If it creates multiple, this needs adjustment.
    # If present_results creates messages with different content, the side_effect might need to be more specific.
    assert mock_other_msg_instance.send.called # Check if send was called on the message from present_results
    # If only one message is expected from present_results:
    # mock_other_msg_instance.send.assert_called_once() 

    # --- Assertions for Visualization Function Calls ---
    mock_gen_dalle.assert_called_once()
    # We can add more specific assertions about the arguments if needed, e.g.,
    # mock_gen_dalle.assert_called_once_with(prompt=ANY, client=ANY)
    mock_gen_mermaid.assert_called_once()
    # mock_gen_mermaid.assert_called_once_with(text_input=ANY, llm_client=ANY)

@pytest.mark.asyncio
async def test_on_message_quick_mode_on_overrides_personas(mock_cl):
    """Test that on_message with quick_mode=True overrides selected_personas before calling invoke_langgraph.

    QUICK_MODE_PERSONAS is patched to a known sentinel list so the assertion
    can confirm that the state handed to insight_flow_graph.ainvoke carries
    the override rather than the personas originally in the session state.
    Progress-message lifecycle assertions mirror the direct_mode_off test.
    """
    mock_incoming_message = MagicMock(spec=cl.Message)
    mock_incoming_message.content = "Test query for quick mode on"

    initial_personas = ["creative", "historical"]
    initial_state_dict = {
        "query": "", "selected_personas": initial_personas, "persona_responses": {},
        "synthesized_response": None, "visualization_code": None,
        "visualization_image_url": None, "current_step_name": "awaiting_query",
        "error_message": None, "panel_type": "research"
    }
    
    setup_mock_cl_session_get(mock_cl, direct_mode_value=False, quick_mode_value=True, initial_state_dict=initial_state_dict)
    
    # --- Mock cl.Message for progress updates (similar to test_on_message_direct_mode_off_calls_langgraph) --- #
    mock_progress_message_instance = AsyncMock(spec=cl.Message)
    mock_progress_message_instance.send = AsyncMock()
    mock_progress_message_instance.stream_token = AsyncMock()
    mock_progress_message_instance.update = AsyncMock()

    default_mock_other_message_instance = AsyncMock(spec=cl.Message, send=AsyncMock())
    def message_side_effect_quick_on(*args, **kwargs):
        if kwargs.get("content") == "": 
            return mock_progress_message_instance
        return default_mock_other_message_instance
    mock_cl.Message.side_effect = message_side_effect_quick_on
    mock_cl.Message.reset_mock() # Reset from other tests

    expected_quick_mode_personas = ["test_quick1", "test_quick2"]
    passed_progress_msg_to_callback_quick_on = None

    # Define a unique local mock handler class for this test to avoid nonlocal conflicts
    class MockCBHandlerQuickOnLocal(app.InsightFlowCallbackHandler): # Changed name
        def __init__(self, progress_message: cl.Message):
            nonlocal passed_progress_msg_to_callback_quick_on
            passed_progress_msg_to_callback_quick_on = progress_message
            # super().__init__(progress_message) # Not strictly necessary for the mock's role
            self.progress_message = progress_message

    def mock_ifch_constructor_side_effect_quick_on(progress_message):
        return MockCBHandlerQuickOnLocal(progress_message=progress_message)

    with patch(f'{APP_PY_PATH}.cl', new=mock_cl):
        with patch(f'{APP_PY_PATH}.QUICK_MODE_PERSONAS', new=expected_quick_mode_personas):
            with patch(f'{APP_PY_PATH}.insight_flow_graph.ainvoke', new_callable=AsyncMock) as mock_graph_ainvoke:
                # Patch InsightFlowCallbackHandler with a MagicMock whose side_effect instantiates the local mock
                with patch(f'{APP_PY_PATH}.InsightFlowCallbackHandler') as mock_ifch_constructor_quick_on:
                    mock_ifch_constructor_quick_on.side_effect = mock_ifch_constructor_side_effect_quick_on
                    from app import on_message
                    await on_message(mock_incoming_message)

    # Progress Message Assertions
    mock_cl.Message.assert_any_call(content="")
    mock_progress_message_instance.send.assert_called_once()
    mock_progress_message_instance.stream_token.assert_any_call("⏳ Initializing InsightFlow process...")
    mock_cl.user_session.set.assert_any_call("progress_msg", mock_progress_message_instance)
    # Assert on the MagicMock that replaced the class constructor
    mock_ifch_constructor_quick_on.assert_called_once_with(progress_message=mock_progress_message_instance)
    assert passed_progress_msg_to_callback_quick_on is mock_progress_message_instance # Check captured arg
    mock_progress_message_instance.stream_token.assert_any_call("\n✨ InsightFlow processing complete!")
    mock_progress_message_instance.update.assert_called_once()

    # Original Assertions
    mock_graph_ainvoke.assert_called_once()
    called_state = mock_graph_ainvoke.call_args[0][0]
    assert called_state["selected_personas"] == expected_quick_mode_personas
    assert called_state["query"] == mock_incoming_message.content
    mock_cl.user_session.get.assert_any_call("direct_mode")
    mock_cl.user_session.get.assert_any_call("quick_mode", False)
    mock_cl.user_session.get.assert_any_call("insight_flow_state")

@pytest.mark.asyncio
async def test_on_message_quick_mode_off_uses_original_personas(mock_cl):
    """With quick_mode disabled, on_message must hand the originally selected
    personas through to invoke_langgraph unchanged, while still driving the
    progress-message lifecycle (create, send, stream tokens, update)."""
    incoming = MagicMock(spec=cl.Message)
    incoming.content = "Test query for quick mode off"

    chosen_personas = ["original1", "original2"]
    starting_state = {
        "query": "", "selected_personas": chosen_personas, "persona_responses": {},
        "synthesized_response": None, "visualization_code": None,
        "visualization_image_url": None, "current_step_name": "awaiting_query",
        "error_message": None, "panel_type": "research"
    }

    setup_mock_cl_session_get(mock_cl, direct_mode_value=False, quick_mode_value=False, initial_state_dict=starting_state)

    # Progress message stub: invoke_langgraph creates it as cl.Message(content="").
    progress_stub = AsyncMock(spec=cl.Message)
    progress_stub.send = AsyncMock()
    progress_stub.stream_token = AsyncMock()
    progress_stub.update = AsyncMock()

    # Any other cl.Message construction gets a generic sendable stub.
    generic_stub = AsyncMock(spec=cl.Message, send=AsyncMock())

    def build_message(*args, **kwargs):
        return progress_stub if kwargs.get("content") == "" else generic_stub

    mock_cl.Message.side_effect = build_message
    mock_cl.Message.reset_mock()  # Clear constructor call counts left by earlier tests

    captured_progress_msg = None

    class RecordingCBHandlerQuickOff(app.InsightFlowCallbackHandler):
        # Records the progress message handed to the handler constructor.
        def __init__(self, progress_message: cl.Message):
            nonlocal captured_progress_msg
            captured_progress_msg = progress_message
            self.progress_message = progress_message

    def build_handler(progress_message):
        return RecordingCBHandlerQuickOff(progress_message=progress_message)

    with patch(f'{APP_PY_PATH}.cl', new=mock_cl):
        with patch(f'{APP_PY_PATH}.insight_flow_graph.ainvoke', new_callable=AsyncMock) as graph_ainvoke_mock:
            # Replace the handler class; its side_effect builds the recording subclass.
            with patch(f'{APP_PY_PATH}.InsightFlowCallbackHandler') as handler_ctor_mock:
                handler_ctor_mock.side_effect = build_handler
                from app import on_message
                await on_message(incoming)

    # --- Progress-message lifecycle ---
    mock_cl.Message.assert_any_call(content="")
    progress_stub.send.assert_called_once()
    progress_stub.stream_token.assert_any_call("⏳ Initializing InsightFlow process...")
    mock_cl.user_session.set.assert_any_call("progress_msg", progress_stub)
    handler_ctor_mock.assert_called_once_with(progress_message=progress_stub)
    assert captured_progress_msg is progress_stub
    progress_stub.stream_token.assert_any_call("\n✨ InsightFlow processing complete!")
    progress_stub.update.assert_called_once()

    # --- Persona routing ---
    graph_ainvoke_mock.assert_called_once()
    state_to_graph = graph_ainvoke_mock.call_args[0][0]
    assert state_to_graph["selected_personas"] == chosen_personas
    assert state_to_graph["query"] == incoming.content
    mock_cl.user_session.get.assert_any_call("direct_mode")
    mock_cl.user_session.get.assert_any_call("quick_mode", False)
    mock_cl.user_session.get.assert_any_call("insight_flow_state")