File size: 7,528 Bytes
31add3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import pytest
from unittest.mock import AsyncMock, MagicMock
from openai import AsyncOpenAI
from openai.types.images_response import Image, ImagesResponse # For mocking response
from typing import Optional

from utils.visualization_utils import generate_dalle_image

# NOTE: the real generate_dalle_image is imported from utils.visualization_utils above.
# The stub below is the placeholder signature the tests were originally written against:
# async def generate_dalle_image(prompt: str, client: AsyncOpenAI) -> Optional[str]:
#     pass

@pytest.mark.asyncio
async def test_generate_dalle_image_success():
    """A well-formed images response makes the helper return the image URL."""
    url = "https://example.com/image.png"
    prompt = "A beautiful sunset over mountains, hand-drawn sketch style"

    # Shape of the mocked payload: an ImagesResponse whose ``data`` list
    # holds Image objects, each carrying a ``url`` attribute.
    api_response = ImagesResponse(
        created=1234567890,
        data=[Image(b64_json=None, revised_prompt=None, url=url)],
    )
    client = AsyncMock(spec=AsyncOpenAI)
    client.images.generate = AsyncMock(return_value=api_response)

    result = await generate_dalle_image(prompt, client)

    # Exact keyword arguments the helper is expected to pass to the API.
    expected_kwargs = {
        "prompt": prompt,
        "model": "dall-e-3",
        "size": "1024x1024",
        # DALL-E 3 only supports 'standard' and 'hd'.
        "quality": "standard",
        "n": 1,
        # 'vivid' or 'natural'; the "hand-drawn" look is left to the prompt.
        "style": "vivid",
    }
    client.images.generate.assert_called_once_with(**expected_kwargs)
    assert result == url

@pytest.mark.asyncio
async def test_generate_dalle_image_api_error():
    """An exception raised by the images endpoint is expected to yield ``None``."""
    client = AsyncMock(spec=AsyncOpenAI)
    # Every call to the endpoint raises instead of returning a response.
    client.images.generate = AsyncMock(side_effect=Exception("API Error"))

    result = await generate_dalle_image("A complex abstract concept", client)

    client.images.generate.assert_called_once()
    assert result is None

@pytest.mark.asyncio
async def test_generate_dalle_image_no_data_in_response():
    """A response whose ``data`` list is empty is expected to yield ``None``."""
    empty_response = ImagesResponse(created=1234567890, data=[])
    client = AsyncMock(spec=AsyncOpenAI)
    client.images.generate = AsyncMock(return_value=empty_response)

    result = await generate_dalle_image("A rare scenario", client)

    client.images.generate.assert_called_once()
    assert result is None

@pytest.mark.asyncio
async def test_generate_dalle_image_no_url_in_image_object():
    """An Image entry without a URL is expected to yield ``None``."""
    # The response contains one image, but its ``url`` field is unset.
    urlless_image = Image(b64_json=None, revised_prompt=None, url=None)  # type: ignore
    api_response = ImagesResponse(created=1234567890, data=[urlless_image])

    client = AsyncMock(spec=AsyncOpenAI)
    client.images.generate = AsyncMock(return_value=api_response)

    result = await generate_dalle_image("Another rare scenario", client)

    client.images.generate.assert_called_once()
    assert result is None

# --- Tests for generate_mermaid_code --- #
from langchain_openai import ChatOpenAI # For type hinting the mock LLM client
from langchain_core.messages import AIMessage # For mocking LLM response

# generate_mermaid_code is imported inside each test below rather than at module level:
# from utils.visualization_utils import generate_mermaid_code 

# Prompt text the tests expect generate_mermaid_code to send to the LLM.
# test_generate_mermaid_code_success formats it with ``text_input`` and compares
# it for exact equality with the first message's content, so any whitespace
# change here changes the test outcome.
# NOTE(review): presumably this must mirror the prompt defined in
# utils.visualization_utils character for character — confirm against that module.
MERMAID_SYSTEM_PROMPT_TEMPLATE = """You are an expert in creating Mermaid diagrams. Based on the following text, generate a concise and accurate Mermaid diagram syntax. Only output the Mermaid code block (```mermaid\n...
```). Do not include any other explanatory text. If the text cannot be reasonably converted to a diagram, output '// No suitable diagram' as a comment. Text: {text_input}
"""

@pytest.mark.asyncio
async def test_generate_mermaid_code_success():
    """A fenced Mermaid reply is expected to come back as the bare diagram code."""
    # Imported here, inside the test, rather than at the top of the module.
    from utils.visualization_utils import generate_mermaid_code

    diagram = (
        "graph TD;\n"
        "    A[Start] --> B{Is it?};\n"
        "    B -- Yes --> C[End];\n"
        "    B -- No --> D[Alternative End];"
    )
    source_text = "Describe a simple decision process."

    # The mocked LLM replies with the diagram wrapped in a ```mermaid fence.
    llm = AsyncMock(spec=ChatOpenAI)
    llm.ainvoke = AsyncMock(
        return_value=AIMessage(content="```mermaid\n" + diagram + "\n```")
    )

    output = await generate_mermaid_code(source_text, llm)

    llm.ainvoke.assert_called_once()

    # This test assumes ainvoke receives a sequence of messages as its first
    # positional argument, with the formatted prompt as the first message.
    # Adjust if the implementation passes a plain string instead.
    messages = llm.ainvoke.call_args.args[0]
    assert len(messages) > 0
    assert messages[0].content == MERMAID_SYSTEM_PROMPT_TEMPLATE.format(
        text_input=source_text
    )

    assert output == diagram

@pytest.mark.asyncio
async def test_generate_mermaid_code_llm_error():
    """An exception raised by the LLM call is expected to yield ``None``."""
    from utils.visualization_utils import generate_mermaid_code

    llm = AsyncMock(spec=ChatOpenAI)
    # Every ainvoke call raises instead of producing a message.
    llm.ainvoke = AsyncMock(side_effect=Exception("LLM API Error"))

    output = await generate_mermaid_code(
        "Some complex text that might cause an error.", llm
    )

    llm.ainvoke.assert_called_once()
    # The test expects None here; a specific error string was the alternative considered.
    assert output is None

@pytest.mark.asyncio
async def test_generate_mermaid_code_no_suitable_diagram():
    """The '// No suitable diagram' reply is expected to map to ``None``."""
    from utils.visualization_utils import generate_mermaid_code

    llm = AsyncMock(spec=ChatOpenAI)
    # The reply is the bare sentinel comment, with no fenced code block.
    llm.ainvoke = AsyncMock(return_value=AIMessage(content="// No suitable diagram"))

    output = await generate_mermaid_code(
        "This text is not suitable for a diagram.", llm
    )

    llm.ainvoke.assert_called_once()
    # Assumption made by this test: the sentinel is translated to None rather
    # than being returned to the caller verbatim.
    assert output is None