import numpy as np  # noqa: F401 - not used directly; imported before torch to control import order
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import time
import os
import re
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class SummariserService:
    def __init__(self):
        # Status tracking
        self.model_loading_status = {
            "is_loading": False,
            "step": "",
            "progress": 0
        }

        # Candidate models for different kinds of content
        model_options = {
            "general": "facebook/bart-large-cnn",
            "news": "facebook/bart-large-xsum",
            "long_form": "google/pegasus-large",
            "literary": "t5-large"
        }

        # BART (fine-tuned on CNN/DailyMail) handles general web content well
        model_name = model_options["general"]

        # Update loading status
        self.model_loading_status["is_loading"] = True
        self.model_loading_status["step"] = "Initializing tokenizer"

        # Ensure cache directory exists and is writable
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
        os.makedirs(cache_dir, exist_ok=True)

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                local_files_only=False
            )

            self.model_loading_status["step"] = "Loading model"
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                force_download=False,
                local_files_only=False
            )

            # Move to GPU if available
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)

        except Exception as e:
            # Fall back to a smaller model if the main one fails to load
            logger.error(f"Error loading model {model_name}: {str(e)}")
            logger.warning("Falling back to smaller model...")

            fallback_model = "sshleifer/distilbart-cnn-6-6"  # Much smaller model

            self.tokenizer = AutoTokenizer.from_pretrained(
                fallback_model,
                cache_dir=cache_dir,
                local_files_only=False
            )

            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                fallback_model,
                cache_dir=cache_dir,
                force_download=False,
                local_files_only=False
            )

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)

            # Update model name for metadata
            model_name = fallback_model

        self.model_loading_status["is_loading"] = False
        self.model_loading_status["progress"] = 100

        # Store the actual model name used
        self.model_name = model_name

        # Track current processing job
        self.current_job = {
            "in_progress": False,
            "start_time": None,
            "input_word_count": 0,
            "estimated_time": 0,
            "stage": "",
            "progress": 0
        }

    def clean_summary(self, summary):
        """Clean and format the summary text."""
        # Remove any leading punctuation or spaces
        summary = re.sub(r'^[,.\s]+', '', summary)

        # Capitalize the first letter
        if summary:
            summary = summary[0].upper() + summary[1:]

        # Ensure the summary ends with sentence-final punctuation: trim any
        # trailing fragment back to the last complete sentence, or add a period
        if summary and not summary.endswith(('.', '!', '?')):
            last_sentence_end = max(
                summary.rfind('.'),
                summary.rfind('!'),
                summary.rfind('?')
            )
            if last_sentence_end > 0:
                summary = summary[:last_sentence_end + 1]
            else:
                summary = summary + '.'

        return summary

    def get_status(self):
        """Return the current status of the summarizer service"""
        status = {
            "model_loading": self.model_loading_status,
            "device": self.device,
            "current_job": self.current_job
        }

        # Update estimated time remaining if job in progress
        if self.current_job["in_progress"] and self.current_job["start_time"]:
            elapsed = time.time() - self.current_job["start_time"]
            estimated = self.current_job["estimated_time"]
            remaining = max(0, estimated - elapsed)
            status["current_job"]["time_remaining"] = round(remaining, 1)

            # Update progress based on time
            if estimated > 0:
                progress = min(95, (elapsed / estimated) * 100)
                status["current_job"]["progress"] = round(progress, 0)

        return status

    def summarise(self, text, max_length=250, min_length=100, do_sample=True, temperature=1.2):
        """
        Summarise the given text using the loaded model.

        Args:
            text (str): The text to summarise
            max_length (int): Maximum length of the summary in tokens
            min_length (int): Minimum length of the summary in tokens
            do_sample (bool): Whether to use sampling during generation
            temperature (float): Sampling temperature (higher = more random;
                only applied when do_sample is True)

        Returns:
            dict: The generated summary and processing metadata

        A usage example is included at the bottom of this module.
        """
        logger.info(f"Starting summarization of text with {len(text)} characters")

        # Reset and start job tracking
        self.current_job = {
            "in_progress": True,
            "start_time": time.time(),
            "input_word_count": len(text.split()),
            "estimated_time": max(1, min(30, len(text.split()) / 500)),  # Rough estimate
            "stage": "Tokenizing input text",
            "progress": 5
        }

        result = {
            "summary": "",
            "metadata": {
                "input_word_count": self.current_job["input_word_count"],
                "estimated_time_seconds": self.current_job["estimated_time"],
                "model_used": self.model_name,
                "processing_device": self.device
            }
        }

        try:
            # Preprocess the text to focus on main content
            text = self.preprocess_text(text)
            logger.info(f"After preprocessing: {len(text)} characters")

            # Tokenization step
            inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
            input_ids = inputs.input_ids.to(self.device)

            # Update metadata with token info
            result["metadata"]["input_token_count"] = len(input_ids[0])
            result["metadata"]["truncated"] = len(input_ids[0]) == 1024

            # Update job status
            self.current_job["stage"] = "Generating summary"
            self.current_job["progress"] = 30

            # Generation parameters tuned for web-content summarization;
            # top_k and top_p only take effect when do_sample is True
            summary_ids = self.model.generate(
                input_ids,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                num_beams=5,  # Beam search width
                early_stopping=True,
                no_repeat_ngram_size=3,  # Block repeated trigrams
                length_penalty=2.0,  # Favour longer, more complete summaries
                top_k=50,  # Sample only from the 50 most likely tokens
                top_p=0.95,  # Nucleus sampling threshold
            )

            # Update job status
            self.current_job["stage"] = "Post-processing summary"
            self.current_job["progress"] = 90

            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

            # Clean and format the summary
            summary = self.clean_summary(summary)

            result["summary"] = summary
            result["metadata"]["output_word_count"] = len(summary.split())
            result["metadata"]["compression_ratio"] = round(len(summary.split()) / self.current_job["input_word_count"] * 100, 1)

            logger.info(f"Generated summary with {len(summary)} characters")

        except Exception as e:
            logger.error(f"Error during summarization: {str(e)}")
            result["summary"] = "An error occurred during summarization. Please try again with a shorter text or different parameters."
            result["error"] = str(e)
        finally:
            # Complete job
            self.current_job["in_progress"] = False
            self.current_job["stage"] = "Complete"
            self.current_job["progress"] = 100

        return result

    def preprocess_text(self, text):
        """Preprocess text to improve summarization quality."""
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text)

        # Remove common web page boilerplate text
        text = re.sub(r'Skip to (content|main).*?»', '', text)
        text = re.sub(r'Search for:.*?Search', '', text)
        text = re.sub(r'Menu.*?Resources', '', text, flags=re.DOTALL)

        # Remove comment sections (often start with phrases like "X responses to")
        text = re.sub(r'\d+ responses to.*?$', '', text, flags=re.DOTALL)

        # Remove form fields and subscription prompts
        text = re.sub(r'(Your email address will not be published|Required fields are marked).*?$', '', text, flags=re.DOTALL)

        # Focus on the first part of very long texts (likely the main content)
        if len(text) > 10000:
            text = text[:10000]

        return text
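

# A minimal usage sketch (assumption: the module is run directly as a script;
# the sample_text below is hypothetical and only illustrates the call pattern).
# Instantiating the service downloads model weights on first run, which can
# take several minutes on a slow connection.
if __name__ == "__main__":
    service = SummariserService()

    sample_text = (
        "The James Webb Space Telescope is a space telescope designed to "
        "conduct infrared astronomy. As the largest optical telescope in "
        "space, it can view objects too old, distant, or faint for the "
        "Hubble Space Telescope. "
    ) * 20  # Repeat to approximate a longer article

    result = service.summarise(sample_text, max_length=150, min_length=50)
    print(result["summary"])
    print(result["metadata"])

    # get_status() is designed to be polled (e.g. from another thread or an
    # HTTP endpoint) while summarise() is running
    print(service.get_status())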