import pytest
from unittest.mock import MagicMock, patch
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import SystemMessage, HumanMessage, AIMessageChunk
# Assuming utils.persona is in the Python path
from utils.persona import PersonaReasoning, PersonaFactory
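
# Note: the async test below uses @pytest.mark.asyncio, which requires the
# pytest-asyncio plugin to be installed and enabled.
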
@pytest.fixture
def mock_llm():
"""Fixture for a mocked LLM instance."""
llm = MagicMock(spec=BaseChatModel)
# Mock the behavior of astream
async def mock_astream_behavior(messages):
# Simulate streaming chunks
yield AIMessageChunk(content="Hello, ")
yield AIMessageChunk(content="world!")
# Simulate an empty chunk which can happen
yield AIMessageChunk(content="")
yield AIMessageChunk(content=" How are you?")
    llm.astream = MagicMock(side_effect=mock_astream_behavior)
    return llm
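
# The test below assumes PersonaReasoning.generate_perspective builds a
# [SystemMessage, HumanMessage] prompt from the persona's system_prompt and
# the query, then concatenates the streamed chunk contents; this shape is
# inferred from the assertions, not from the actual implementation.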
@pytest.mark.asyncio
async def test_persona_reasoning_generate_perspective(mock_llm):
"""Test that PersonaReasoning.generate_perspective calls its LLM correctly and returns aggregated content."""
persona_id = "test_persona"
name = "Test Persona"
system_prompt = "You are a test persona."
reasoning = PersonaReasoning(persona_id, name, system_prompt, mock_llm)
query = "What is the meaning of life?"
expected_response = "Hello, world! How are you?"
actual_response = await reasoning.generate_perspective(query)
# Verify LLM call
mock_llm.astream.assert_called_once()
call_args = mock_llm.astream.call_args[0][0] # Get the first positional argument (messages list)
assert len(call_args) == 2
assert isinstance(call_args[0], SystemMessage)
assert call_args[0].content == system_prompt
assert isinstance(call_args[1], HumanMessage)
assert call_args[1].content == query
# Verify aggregated response
assert actual_response == expected_response
def test_persona_factory_initialization():
"""Test PersonaFactory initialization and config loading."""
factory = PersonaFactory()
assert len(factory.persona_configs) > 0 # Check that some configs are loaded
assert "analytical" in factory.persona_configs
assert factory.persona_configs["analytical"]["name"] == "Analytical"
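
# The assertions above assume persona_configs entries look roughly like
# {"analytical": {"name": "Analytical", "system_prompt": "..."}}; this shape
# is inferred from the tests, since the actual config source is not shown here.
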
def test_persona_factory_create_persona_success(mock_llm):
"""Test successful creation of a PersonaReasoning instance."""
factory = PersonaFactory()
persona_id = "analytical"
persona_instance = factory.create_persona(persona_id, mock_llm)
assert persona_instance is not None
assert isinstance(persona_instance, PersonaReasoning)
assert persona_instance.persona_id == persona_id
assert persona_instance.name == factory.persona_configs[persona_id]["name"]
assert persona_instance.system_prompt == factory.persona_configs[persona_id]["system_prompt"]
assert persona_instance.llm == mock_llm
def test_persona_factory_create_persona_invalid_id(mock_llm):
"""Test creating a persona with an invalid ID returns None."""
factory = PersonaFactory()
persona_instance = factory.create_persona("non_existent_persona", mock_llm)
assert persona_instance is None
def test_persona_factory_create_persona_no_llm():
"""Test creating a persona without an LLM instance returns None."""
factory = PersonaFactory()
# We need a way to pass a 'None' LLM or ensure BaseChatModel type hint isn't violated
# For now, let's assume the type hint means it must be a BaseChatModel.
# The implementation checks `if config and llm_instance:`
# So passing a non-BaseChatModel or None should ideally be handled by create_persona.
# Let's test with None if the type hint allows, or by how create_persona handles it.
# The implementation prints an error if llm_instance is None, and returns None.
# Patch print to check for the error message if desired, but for now, just check None return
with patch('utils.persona.base.print') as mock_print: # Patched print in the correct module
persona_instance = factory.create_persona("analytical", None) # Pass None for LLM
assert persona_instance is None
mock_print.assert_any_call("DEBUG Error: LLM instance not provided for persona analytical")
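
# Design note: pytest's built-in capsys fixture could also capture the printed
# error instead of patching print; patching keeps the assertion targeted at
# the exact call.
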
def test_get_available_personas():
"""Test that get_available_personas returns the expected dictionary."""
factory = PersonaFactory()
available = factory.get_available_personas()
assert isinstance(available, dict)
assert "analytical" in available
assert available["analytical"] == "Analytical"
assert len(available) == len(factory.persona_configs) |