from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class SummariserService:
    def __init__(self):
        # Load the BART model fine-tuned on CNN/DailyMail for summarisation
        model_name = "facebook/bart-large-cnn"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Move the model to the GPU if one is available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def summarise(self, text, max_length=150, min_length=50, do_sample=False, temperature=1.0):
        """
        Summarise the given text using the loaded model.

        Args:
            text (str): The text to summarise
            max_length (int): Maximum length of the summary
            min_length (int): Minimum length of the summary
            do_sample (bool): Whether to use sampling for generation
            temperature (float): Sampling temperature (higher = more random)

        Returns:
            str: The generated summary
        """
        # Ensure the text is within the model's max token limit
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        inputs = inputs.to(self.device)

        # Set generation parameters
        generation_params = {
            "max_length": max_length,
            "min_length": min_length,
            "num_beams": 4,
            "length_penalty": 2.0,
            "early_stopping": True,
        }

        # Handle sampling and temperature
        if do_sample:
            try:
                # First attempt: try with the requested temperature
                generation_params["do_sample"] = True
                generation_params["temperature"] = temperature
                summary_ids = self.model.generate(
                    inputs["input_ids"],
                    **generation_params
                )
            except Exception as e:
                # If that fails, retry with the default temperature (1.0)
                print(f"Error with temperature {temperature}, falling back to default: {str(e)}")
                generation_params["temperature"] = 1.0
                try:
                    summary_ids = self.model.generate(
                        inputs["input_ids"],
                        **generation_params
                    )
                except Exception:
                    # If sampling still fails, fall back to beam search without sampling
                    print("Sampling failed, falling back to beam search")
                    generation_params.pop("do_sample", None)
                    generation_params.pop("temperature", None)
                    summary_ids = self.model.generate(
                        inputs["input_ids"],
                        **generation_params
                    )
        else:
            # Standard beam search without sampling
            summary_ids = self.model.generate(
                inputs["input_ids"],
                **generation_params
            )

        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary
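

# Usage sketch (illustrative, not part of the original service): a minimal,
# hedged example of how the class above might be exercised. The sample text
# and parameter values below are assumptions for demonstration only.
if __name__ == "__main__":
    service = SummariserService()
    sample_text = (
        "The transformer architecture now underpins most work in natural "
        "language processing. Pre-trained sequence-to-sequence models such as "
        "BART can be fine-tuned on news datasets to produce fluent abstractive "
        "summaries of long articles."
    )

    # Deterministic beam-search summary
    print(service.summarise(sample_text, max_length=60, min_length=20))

    # Sampled summary with a higher temperature for more varied wording
    print(service.summarise(sample_text, do_sample=True, temperature=1.3))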