import logging
import os
import re
import time

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SummariserService:

    def __init__(self):
        # Loading status, exposed via get_status() so callers can poll.
        self.model_loading_status = {
            "is_loading": False,
            "step": "",
            "progress": 0
        }

        # Candidate checkpoints for different kinds of source text.
        model_options = {
            "general": "facebook/bart-large-cnn",
            "news": "facebook/bart-large-xsum",
            "long_form": "google/pegasus-large",
            "literary": "t5-large"
        }
        model_name = model_options["general"]

        self.model_loading_status["is_loading"] = True
        self.model_loading_status["step"] = "Initializing tokenizer"

        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
        os.makedirs(cache_dir, exist_ok=True)

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                local_files_only=False
            )

            self.model_loading_status["step"] = "Loading model"
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                force_download=False,
                local_files_only=False
            )
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {e}")
            logger.warning("Falling back to smaller model...")

            fallback_model = "sshleifer/distilbart-cnn-6-6"
            self.tokenizer = AutoTokenizer.from_pretrained(
                fallback_model,
                cache_dir=cache_dir,
                local_files_only=False
            )
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                fallback_model,
                cache_dir=cache_dir,
                force_download=False,
                local_files_only=False
            )
            model_name = fallback_model

        # Move whichever model loaded onto the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        self.model_loading_status["is_loading"] = False
        self.model_loading_status["progress"] = 100
        self.model_name = model_name

        # State of the most recent summarisation job.
        self.current_job = {
            "in_progress": False,
            "start_time": None,
            "input_word_count": 0,
            "estimated_time": 0,
            "stage": "",
            "progress": 0
        }

    def clean_summary(self, summary):
        """Clean and format the summary text."""
        # Strip stray leading punctuation and whitespace.
        summary = re.sub(r'^[,.\s]+', '', summary)

        # Capitalise the first character.
        if summary:
            summary = summary[0].upper() + summary[1:]

        # If generation stopped mid-sentence, cut back to the last complete
        # sentence, or add a full stop if there is no sentence boundary.
        if summary and not summary.endswith(('.', '!', '?')):
            last_sentence_end = max(
                summary.rfind('.'),
                summary.rfind('!'),
                summary.rfind('?')
            )
            if last_sentence_end > 0:
                summary = summary[:last_sentence_end + 1]
            else:
                summary += '.'

        return summary
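
    # Hypothetical before/after examples of clean_summary (not from a run):
    #   ", the court ruled in favour of the plaintiff. it also"
    #     -> "The court ruled in favour of the plaintiff."
    #   "no terminal punctuation here"
    #     -> "No terminal punctuation here."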

    def get_status(self):
        """Return the current status of the summariser service."""
        status = {
            "model_loading": self.model_loading_status,
            "device": self.device,
            "current_job": self.current_job
        }

        if self.current_job["in_progress"] and self.current_job["start_time"]:
            elapsed = time.time() - self.current_job["start_time"]
            estimated = self.current_job["estimated_time"]
            remaining = max(0, estimated - elapsed)
            status["current_job"]["time_remaining"] = round(remaining, 1)

            # Derive progress from elapsed time, capped at 95% so the job
            # never looks finished before it actually completes.
            if estimated > 0:
                progress = min(95, (elapsed / estimated) * 100)
                status["current_job"]["progress"] = round(progress, 0)

        return status
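
    # A snapshot returned mid-job might look like this (values illustrative):
    # {
    #     "model_loading": {"is_loading": False, "step": "Loading model", "progress": 100},
    #     "device": "cuda",
    #     "current_job": {"in_progress": True, "stage": "Generating summary",
    #                     "time_remaining": 4.2, "progress": 58.0, ...}
    # }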

    def summarise(self, text, max_length=250, min_length=100, do_sample=True, temperature=1.2):
        """
        Summarise the given text using the loaded model.

        Args:
            text (str): The text to summarise
            max_length (int): Maximum length of the summary in tokens
            min_length (int): Minimum length of the summary in tokens
            do_sample (bool): Whether to use sampling for generation
            temperature (float): Sampling temperature (higher = more random)

        Returns:
            dict: The generated summary and processing metadata
        """
        logger.info(f"Starting summarisation of text with {len(text)} characters")

        self.current_job = {
            "in_progress": True,
            "start_time": time.time(),
            "input_word_count": len(text.split()),
            "estimated_time": max(1, min(30, len(text.split()) / 500)),
            "stage": "Tokenizing input text",
            "progress": 5
        }

        result = {
            "summary": "",
            "metadata": {
                "input_word_count": self.current_job["input_word_count"],
                "estimated_time_seconds": self.current_job["estimated_time"],
                "model_used": self.model_name,
                "processing_device": self.device
            }
        }

        try:
            text = self.preprocess_text(text)
            logger.info(f"After preprocessing: {len(text)} characters")

            # Tokenise, truncating to the model's 1024-token input limit.
            inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
            input_ids = inputs.input_ids.to(self.device)

            result["metadata"]["input_token_count"] = input_ids.shape[1]
            result["metadata"]["truncated"] = input_ids.shape[1] == 1024

            self.current_job["stage"] = "Generating summary"
            self.current_job["progress"] = 30

            # Pass sampling parameters only when sampling is enabled, so
            # transformers does not warn about unused generation arguments.
            generation_kwargs = {
                "max_length": max_length,
                "min_length": min_length,
                "num_beams": 5,
                "early_stopping": True,
                "no_repeat_ngram_size": 3,
                "length_penalty": 2.0,
            }
            if do_sample:
                generation_kwargs.update(
                    do_sample=True,
                    temperature=temperature,
                    top_k=50,
                    top_p=0.95,
                )

            summary_ids = self.model.generate(input_ids, **generation_kwargs)

            self.current_job["stage"] = "Post-processing summary"
            self.current_job["progress"] = 90

            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            summary = self.clean_summary(summary)

            result["summary"] = summary
            result["metadata"]["output_word_count"] = len(summary.split())
            result["metadata"]["compression_ratio"] = round(
                len(summary.split()) / max(1, self.current_job["input_word_count"]) * 100, 1
            )

            logger.info(f"Generated summary with {len(summary)} characters")
        except Exception as e:
            logger.error(f"Error during summarisation: {e}")
            result["summary"] = ("An error occurred during summarisation. Please try "
                                 "again with a shorter text or different parameters.")
            result["error"] = str(e)
        finally:
            self.current_job["in_progress"] = False
            self.current_job["stage"] = "Complete"
            self.current_job["progress"] = 100

        return result
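
    # Illustrative parameter choices (calls are hypothetical): for
    # reproducible output use summarise(text, do_sample=False), which keeps
    # pure beam search; raising temperature above the 1.2 default trades
    # coherence for more varied phrasing.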

    def preprocess_text(self, text):
        """Preprocess text to improve summarisation quality."""
        # Collapse runs of whitespace into single spaces.
        text = re.sub(r'\s+', ' ', text)

        # Strip common website navigation fragments.
        text = re.sub(r'Skip to (content|main).*?»', '', text)
        text = re.sub(r'Search for:.*?Search', '', text)
        text = re.sub(r'Menu.*?Resources', '', text, flags=re.DOTALL)

        # Strip comment threads ("N responses to ...").
        text = re.sub(r'\d+ responses to.*?$', '', text, flags=re.DOTALL)

        # Strip comment-form boilerplate.
        text = re.sub(r'(Your email address will not be published|Required fields are marked).*?$', '', text, flags=re.DOTALL)

        # Cap very long inputs; the tokenizer truncates to 1024 tokens anyway.
        if len(text) > 10000:
            text = text[:10000]

        return text
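

# Minimal usage sketch (a hypothetical smoke test, not part of the service):
# constructing SummariserService downloads bart-large-cnn on first use, which
# can take several minutes; the sample text below is only a placeholder.
if __name__ == "__main__":
    service = SummariserService()
    sample = (
        "The James Webb Space Telescope is the largest optical telescope in "
        "space. Its high resolution and sensitivity allow it to view objects "
        "too old, distant or faint for the Hubble Space Telescope. "
    ) * 10  # repeat so the input is long enough to be worth compressing
    output = service.summarise(sample, max_length=80, min_length=20)
    print(output["summary"])
    print(output["metadata"])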