from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class SummariserService:

    def __init__(self):
        # Load the pretrained BART summarisation checkpoint and its tokenizer.
        model_name = "facebook/bart-large-cnn"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Run on the GPU when one is available, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def summarise(self, text, max_length=150, min_length=50, do_sample=False, temperature=1.0):
        """
        Summarise the given text using the loaded model.

        Args:
            text (str): The text to summarise.
            max_length (int): Maximum length of the summary, in tokens.
            min_length (int): Minimum length of the summary, in tokens.
            do_sample (bool): Whether to use sampling during generation.
            temperature (float): Sampling temperature (higher values give more varied output).

        Returns:
            str: The generated summary.
        """
        # Tokenise the input, truncating it to BART's 1024-token encoder limit.
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        inputs = inputs.to(self.device)

        # Default generation settings: beam search with a mild length penalty.
        generation_params = {
            "max_length": max_length,
            "min_length": min_length,
            "num_beams": 4,
            "length_penalty": 2.0,
            "early_stopping": True,
        }

        if do_sample:
            try:
                # Sampling requested: enable it on top of the beam-search defaults.
                generation_params["do_sample"] = True
                generation_params["temperature"] = temperature
                summary_ids = self.model.generate(inputs["input_ids"], **generation_params)
            except Exception as e:
                # First fallback: retry sampling with the default temperature.
                print(f"Error with temperature {temperature}, falling back to default: {e}")
                generation_params["temperature"] = 1.0
                try:
                    summary_ids = self.model.generate(inputs["input_ids"], **generation_params)
                except Exception:
                    # Second fallback: drop sampling entirely and use plain beam search.
                    print("Sampling failed, falling back to beam search")
                    generation_params.pop("do_sample", None)
                    generation_params.pop("temperature", None)
                    summary_ids = self.model.generate(inputs["input_ids"], **generation_params)
        else:
            summary_ids = self.model.generate(inputs["input_ids"], **generation_params)

        # Decode the best candidate and strip special tokens such as <s> and </s>.
        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary
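

# A minimal usage sketch, assuming this module is run directly; the sample text and
# the temperature value below are illustrative, not part of the service itself.
if __name__ == "__main__":
    service = SummariserService()

    sample_text = (
        "The James Webb Space Telescope is the largest optical telescope in space. "
        "Its high resolution and sensitivity allow it to observe objects that are "
        "too old, too distant, or too faint for earlier space telescopes."
    )

    # Deterministic beam-search summary using the default settings.
    print(service.summarise(sample_text))

    # Sampled summary with a higher temperature for more varied wording.
    print(service.summarise(sample_text, do_sample=True, temperature=1.3))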