|
import pytest |
|
from unittest.mock import patch, MagicMock |
|
import sys |
|
import os |
|
|
|
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
from app.services.summariser import SummariserService |
|
|
|
|
|
def test_summariser_with_mock():
    """Check SummariserService wiring with the Hugging Face loader classes stubbed out.

    ``AutoTokenizer`` and ``AutoModelForSeq2SeqLM`` are patched at the names
    ``app.services.summariser`` looks up, so their ``from_pretrained`` calls
    return mocks instead of real objects. The stubbed ``generate`` yields a
    fixed id sequence and the stubbed ``decode`` yields a canned string. The
    service must surface that string as ``result["summary"]``.
    """
    canned_summary = "This is a test summary."

    with patch('app.services.summariser.AutoTokenizer') as tokenizer_cls:
        with patch('app.services.summariser.AutoModelForSeq2SeqLM') as model_cls:
            # Tokenizer double: decoding anything produces the canned summary.
            fake_tokenizer = MagicMock()
            fake_tokenizer.configure_mock(**{"decode.return_value": canned_summary})
            tokenizer_cls.from_pretrained.return_value = fake_tokenizer

            # Model double: ``.to(...)`` returns the same mock, so the service
            # keeps using this object after any device move.
            fake_model = MagicMock()
            fake_model.configure_mock(**{
                "generate.return_value": [[1, 2, 3, 4]],
                "to.return_value": fake_model,
            })
            model_cls.from_pretrained.return_value = fake_model

            service = SummariserService()
            sample = "This is a test paragraph that should be summarized."
            result = service.summarise(sample, max_length=50, min_length=10)

            # Response shape: a dict carrying the decoded summary plus metadata.
            assert isinstance(result, dict)
            assert "summary" in result
            assert result["summary"] == canned_summary
            assert "metadata" in result

            # Each stubbed collaborator should have been used exactly once.
            tokenizer_cls.from_pretrained.assert_called_once()
            model_cls.from_pretrained.assert_called_once()
            fake_model.generate.assert_called_once()
            fake_tokenizer.decode.assert_called_once()
|
|
|
|
|
def test_summariser():
    """Run the real (unpatched) SummariserService on a short multi-sentence paragraph.

    Nothing is mocked here, so this exercises whatever tokenizer/model the
    service loads. It presumably needs those weights to be available locally
    or downloadable — TODO confirm for CI.
    """
    service = SummariserService()
    source_text = "This is a test paragraph that should be summarized. It contains multiple sentences with different information. The summarizer should extract the key points and generate a concise summary."

    output = service.summarise(source_text, max_length=50, min_length=10)

    # Response shape: a dict with a string summary and a metadata entry.
    assert isinstance(output, dict)
    assert "summary" in output
    assert isinstance(output["summary"], str)
    assert "metadata" in output

    produced = output["summary"]
    assert len(produced) > 0

    # A verbatim echo of the input is tolerated; only when the model actually
    # rewrote the text do we demand it be clearly shorter than the source.
    if produced == source_text:
        return
    assert len(produced) < len(source_text) * 0.8
|
|