import pytest
from unittest.mock import patch, MagicMock, AsyncMock

import chainlit as cl  # Import chainlit for its types
from chainlit.input_widget import Switch  # Input widgets live in chainlit.input_widget, not the top-level namespace

# Assuming your app structure allows these imports.
# You might need to adjust paths or ensure __init__.py files are present.
from app import (
    on_chat_start,
    _apply_chat_settings_to_state,
    execute_persona_tasks,
    InsightFlowState,
    PersonaFactory,  # Assuming PersonaFactory is accessible or can be mocked
    PERSONA_LLM_MAP,  # Assuming this is accessible or can be mocked
    RAG_ENABLED_PERSONA_IDS,  # If this is defined globally in app.py
)
from utils.persona.base import PersonaReasoning  # For creating mock persona objects

# Default RAG_ENABLED_PERSONA_IDS from app.py (or import it if it's moved to a config)
APP_RAG_ENABLED_PERSONA_IDS = ["analytical", "philosophical", "metaphorical"]

@pytest.fixture
def mock_cl_user_session():
    """Mocks chainlit.user_session with a dictionary-like store."""
    with patch('chainlit.user_session', new_callable=MagicMock) as mock_session:
        _session_store = {}

        def get_item(key, default=None):
            return _session_store.get(key, default)

        def set_item(key, value):
            _session_store[key] = value

        def contains_item(key):
            # Check membership against _session_store, not mock_session itself
            return key in _session_store

        mock_session.get = MagicMock(side_effect=get_item)
        mock_session.set = MagicMock(side_effect=set_item)
        # Make `key in cl.user_session` work as expected
        mock_session.__contains__ = MagicMock(side_effect=contains_item)

        # Initialize with some common defaults needed by other parts of the app logic
        _session_store["id"] = "test_session_id"
        _session_store["persona_factory"] = PersonaFactory()  # Real or mocked
        _session_store["embedding_model"] = MagicMock()  # Assume it's initialized

        yield mock_session
        _session_store.clear()  # Clean up after the test
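
# Hedged sanity check for the session fixture above: it exercises only the mock wiring
# (get/set round-trip and default fallback), not any application behaviour. Purely
# illustrative; safe to drop.
def test_mock_cl_user_session_roundtrip(mock_cl_user_session):
    mock_cl_user_session.set("some_key", 123)
    assert mock_cl_user_session.get("some_key") == 123
    assert mock_cl_user_session.get("missing_key", "fallback") == "fallback"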

@pytest.fixture
def mock_persona_factory():
    factory = MagicMock(spec=PersonaFactory)

    # Set up mock personas that the factory can create
    mock_analytical_persona = MagicMock(spec=PersonaReasoning)
    mock_analytical_persona.id = "analytical"
    mock_analytical_persona.name = "Analytical Persona"
    mock_analytical_persona.expertise = "Logic and reason"
    mock_analytical_persona.role = "To analyze deeply"
    mock_analytical_persona.generate_perspective = AsyncMock(return_value="Analytical perspective")

    mock_philosophical_persona = MagicMock(spec=PersonaReasoning)
    mock_philosophical_persona.id = "philosophical"
    mock_philosophical_persona.name = "Philosophical Persona"
    mock_philosophical_persona.expertise = "Wisdom and ethics"
    mock_philosophical_persona.role = "To ponder deeply"
    mock_philosophical_persona.generate_perspective = AsyncMock(return_value="Philosophical perspective")

    mock_scientific_persona = MagicMock(spec=PersonaReasoning)  # Non-RAG example
    mock_scientific_persona.id = "scientific"
    mock_scientific_persona.name = "Scientific Persona"
    mock_scientific_persona.expertise = "Empirical evidence"
    mock_scientific_persona.role = "To investigate phenomena"
    mock_scientific_persona.generate_perspective = AsyncMock(return_value="Scientific perspective")

    def create_persona_side_effect(persona_id, llm):
        if persona_id == "analytical":
            return mock_analytical_persona
        if persona_id == "philosophical":
            return mock_philosophical_persona
        if persona_id == "scientific":
            return mock_scientific_persona
        # Fallback persona for any other id. Note: `name` must be assigned after construction,
        # because passing name= to the MagicMock constructor names the mock itself instead of
        # creating a .name attribute.
        fallback = MagicMock(spec=PersonaReasoning)
        fallback.id = persona_id
        fallback.name = f"{persona_id.capitalize()} Persona"
        fallback.generate_perspective = AsyncMock(return_value=f"{persona_id} perspective")
        return fallback

    factory.create_persona = MagicMock(side_effect=create_persona_side_effect)

    # Mock get_available_personas in case _apply_chat_settings_to_state uses it directly
    factory.get_available_personas = MagicMock(
        return_value={"analytical": "Analytical", "philosophical": "Philosophical", "scientific": "Scientific"}
    )
    factory.persona_configs = {  # Used by on_chat_start to build the persona switches
        "analytical": {"name": "Analytical Persona"},
        "philosophical": {"name": "Philosophical Persona"},
        "scientific": {"name": "Scientific Persona"},
    }
    return factory
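
# Hedged sanity check for the factory fixture above: verifies the canned persona mocks and
# the fallback path without touching app code. Purely illustrative; safe to drop.
def test_mock_persona_factory_returns_canned_personas(mock_persona_factory):
    analytical = mock_persona_factory.create_persona("analytical", None)
    assert analytical.name == "Analytical Persona"
    # Unknown ids fall back to a generic persona mock
    other = mock_persona_factory.create_persona("creative", None)
    assert other.name == "Creative Persona"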

@pytest.mark.asyncio
async def test_on_chat_start_adds_rag_toggle(mock_cl_user_session, mock_persona_factory):
    """Test that on_chat_start includes the 'enable_rag' Switch."""
    # Ensure the mock factory is returned for this test; other keys fall back to their defaults.
    mock_cl_user_session.get.side_effect = (
        lambda key, default=None: mock_persona_factory if key == "persona_factory" else default
    )

    # Configure the mock for cl.ChatSettings
    mock_chat_settings_instance = AsyncMock()  # Returned by cl.ChatSettings()
    mock_chat_settings_instance.send = AsyncMock()  # Ensure the send method is an AsyncMock

    with patch('app.initialize_configurations') as mock_init_config, \
         patch('app.PersonaFactory', return_value=mock_persona_factory), \
         patch('chainlit.ChatSettings', return_value=mock_chat_settings_instance) as mock_chat_settings_class, \
         patch('chainlit.Message') as mock_cl_message, \
         patch('app.get_embedding_model', return_value=MagicMock()):  # Mock embedding model loading
        await on_chat_start()

    mock_chat_settings_class.assert_called_once()  # The cl.ChatSettings class was instantiated
    # mock_chat_settings_instance.send.assert_called_once()  # Optionally check that send() was awaited

    args, _ = mock_chat_settings_class.call_args
    inputs_list = args[0]  # The first positional argument is `inputs`

    rag_toggle_present = any(widget.id == "enable_rag" and isinstance(widget, Switch) for widget in inputs_list)
    assert rag_toggle_present, "'enable_rag' Switch not found in ChatSettings"

    rag_toggle_widget = next(widget for widget in inputs_list if widget.id == "enable_rag")
    assert rag_toggle_widget.initial is True, "RAG toggle should be ON by default"
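
# (Hedged aside) If on_chat_start also registers a team selector, a similar presence check
# could be added to the test above; the widget id "selected_team" is only inferred from the
# settings payloads used later in this file and may not match app.py:
#
#   team_select_present = any(getattr(w, "id", None) == "selected_team" for w in inputs_list)
#   assert team_select_present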

@pytest.mark.asyncio
async def test_apply_chat_settings_reads_rag_toggle(mock_cl_user_session, mock_persona_factory):
    """Test that _apply_chat_settings_to_state correctly reads and sets the RAG toggle."""
    # Use a dict-backed side_effect for cl.user_session.get rather than an ordered list of
    # return values: a list relies on call order, which is fragile. The chat_settings payload
    # below simulates the UI sending enable_rag=False.
    _session_store_for_test = {
        "insight_flow_state": InsightFlowState(selected_personas=[]),
        "persona_factory": mock_persona_factory,
        "chat_settings": {"enable_rag": False, "selected_team": "none"},  # Simulate UI sending False
    }

    def mock_get_specific(key, default=None):
        return _session_store_for_test.get(key, default)

    mock_cl_user_session.get = MagicMock(side_effect=mock_get_specific)

    await _apply_chat_settings_to_state()

    # Check that cl.user_session.set was called to update 'enable_rag'
    set_rag_call = next(
        (call for call in mock_cl_user_session.set.call_args_list if call[0][0] == 'enable_rag'),
        None,
    )
    assert set_rag_call is not None, "cl.user_session.set was not called for 'enable_rag'"
    assert set_rag_call[0][1] is False, "'enable_rag' in session should be False"

# Mock the RAG context retrieval function. NOTE: the patch target below is an assumption;
# point it at whichever helper in app.py fetches RAG context for a persona, and add
# new_callable=AsyncMock if that helper is a coroutine.
@pytest.mark.asyncio
@patch('app.get_rag_context_for_persona')
async def test_execute_persona_tasks_rag_toggle_on(mock_get_context, mock_cl_user_session, mock_persona_factory):
    """Test execute_persona_tasks: RAG attempted when toggle is ON for a RAG-enabled persona."""
    mock_cl_user_session.set("enable_rag", True)
    mock_cl_user_session.set("persona_factory", mock_persona_factory)
    mock_cl_user_session.set("embedding_model", MagicMock())  # Mocked embedding model
    mock_cl_user_session.set("progress_msg", AsyncMock(spec=cl.Message))  # Mock progress message
    mock_cl_user_session.set("completed_steps_log", [])

    # Mock PERSONA_LLM_MAP (or ensure it's correctly populated) for 'analytical'
    with patch.dict(PERSONA_LLM_MAP, {"analytical": MagicMock(spec=ChatOpenAI)}, clear=True):
        initial_state = InsightFlowState(
            query="test query for analytical",
            selected_personas=["analytical"],  # RAG-enabled persona
            persona_responses={},
        )
        mock_get_context.return_value = "Retrieved RAG context."

        final_state = await execute_persona_tasks(initial_state)

    mock_get_context.assert_called_once_with(
        "test query for analytical", "analytical", mock_cl_user_session.get("embedding_model")
    )

    # The persona's generate_perspective should have been called with an augmented prompt
    analytical_persona_mock = mock_persona_factory.create_persona("analytical", None)  # Get the mock back
    call_args = analytical_persona_mock.generate_perspective.call_args[0][0]
    assert "Retrieved RAG context:" in call_args
    assert "User Query: test query for analytical" in call_args

    assert final_state["persona_responses"]["analytical"] == "Analytical perspective"

@pytest.mark.asyncio
@patch('app.get_rag_context_for_persona')  # Assumed patch target; see the note above
async def test_execute_persona_tasks_rag_toggle_off(mock_get_context, mock_cl_user_session, mock_persona_factory):
    """Test execute_persona_tasks: RAG NOT attempted when toggle is OFF."""
    mock_cl_user_session.set("enable_rag", False)  # RAG is OFF
    mock_cl_user_session.set("persona_factory", mock_persona_factory)
    mock_cl_user_session.set("embedding_model", MagicMock())
    mock_cl_user_session.set("progress_msg", AsyncMock(spec=cl.Message))
    mock_cl_user_session.set("completed_steps_log", [])

    with patch.dict(PERSONA_LLM_MAP, {"analytical": MagicMock(spec=ChatOpenAI)}, clear=True):
        initial_state = InsightFlowState(
            query="test query for analytical rag off",
            selected_personas=["analytical"],
            persona_responses={},
        )

        final_state = await execute_persona_tasks(initial_state)

    mock_get_context.assert_not_called()

    # The persona's generate_perspective should have been called without RAG context
    analytical_persona_mock = mock_persona_factory.create_persona("analytical", None)
    call_args = analytical_persona_mock.generate_perspective.call_args[0][0]
    assert "Retrieved RAG context:" not in call_args  # No RAG context section in the prompt
    assert "User Query: test query for analytical rag off" in call_args  # Original query is still present
    assert "No specific context from your knowledge base was retrieved" in call_args  # Non-RAG prompt wording
    assert final_state["persona_responses"]["analytical"] == "Analytical perspective"

@pytest.mark.asyncio
@patch('app.get_rag_context_for_persona')  # Assumed patch target; see the note above
async def test_execute_persona_tasks_rag_on_non_rag_persona(mock_get_context, mock_cl_user_session, mock_persona_factory):
    """Test execute_persona_tasks: RAG NOT attempted for a non-RAG-enabled persona even if the toggle is ON."""
    mock_cl_user_session.set("enable_rag", True)  # RAG is ON globally
    mock_cl_user_session.set("persona_factory", mock_persona_factory)
    mock_cl_user_session.set("embedding_model", MagicMock())
    mock_cl_user_session.set("progress_msg", AsyncMock(spec=cl.Message))
    mock_cl_user_session.set("completed_steps_log", [])

    with patch.dict(PERSONA_LLM_MAP, {"scientific": MagicMock(spec=ChatOpenAI)}, clear=True):
        initial_state = InsightFlowState(
            query="test query for scientific",
            selected_personas=["scientific"],  # NON-RAG-enabled persona
            persona_responses={},
        )

        final_state = await execute_persona_tasks(initial_state)

    mock_get_context.assert_not_called()

    scientific_persona_mock = mock_persona_factory.create_persona("scientific", None)
    call_args = scientific_persona_mock.generate_perspective.call_args[0][0]
    assert "Retrieved RAG context:" not in call_args
    assert call_args == "test query for scientific"  # Original query passed through unchanged
    assert final_state["persona_responses"]["scientific"] == "Scientific perspective"

# ChatOpenAI is needed for the spec= on the mocks used when patching PERSONA_LLM_MAP.
# Being module-level, this import is bound before any test runs even though it sits at the
# end of the file.
from langchain_openai import ChatOpenAI
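
# Test-runner note (assumption: pytest-asyncio is the async plugin in use). Instead of marking
# each coroutine test with @pytest.mark.asyncio, the whole suite can opt in via configuration,
# e.g. in pytest.ini:
#
#   [pytest]
#   asyncio_mode = auto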