from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class SummariserService:
    def __init__(self):
        # Load a pretrained summarisation model (BART fine-tuned on CNN/DailyMail)
        model_name = "facebook/bart-large-cnn"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Move the model to the GPU if one is available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def summarise(self, text, max_length=150, min_length=50, do_sample=False, temperature=1.0):
        """
        Summarise the given text using the loaded model.

        Args:
            text (str): The text to summarise
            max_length (int): Maximum length of the summary, in tokens
            min_length (int): Minimum length of the summary, in tokens
            do_sample (bool): Whether to use sampling for generation
            temperature (float): Sampling temperature (higher = more random)

        Returns:
            str: The generated summary
        """
        # Truncate the input so it fits the model's maximum context (1024 tokens for BART)
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
        inputs = inputs.to(self.device)

        # Generation parameters shared by all strategies below
        generation_params = {
            "max_length": max_length,
            "min_length": min_length,
            "num_beams": 4,
            "length_penalty": 2.0,
            "early_stopping": True,
        }

        # Handle sampling and temperature
        if do_sample:
            try:
                # First attempt: sample with the requested temperature
                generation_params["do_sample"] = True
                generation_params["temperature"] = temperature
                summary_ids = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    **generation_params
                )
            except Exception as e:
                # If that fails, retry with the default temperature (1.0)
                print(f"Error with temperature {temperature}, falling back to default: {str(e)}")
                generation_params["temperature"] = 1.0
                try:
                    summary_ids = self.model.generate(
                        inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        **generation_params
                    )
                except Exception:
                    # If sampling still fails, fall back to plain beam search
                    print("Sampling failed, falling back to beam search")
                    generation_params.pop("do_sample", None)
                    generation_params.pop("temperature", None)
                    summary_ids = self.model.generate(
                        inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        **generation_params
                    )
        else:
            # Standard beam search without sampling
            summary_ids = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generation_params
            )

        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary
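

# Illustrative usage sketch: shows how the service might be called, once with plain
# beam search and once with sampling. The sample text and the parameter values below
# are arbitrary examples chosen for demonstration, not values taken from the deployed API.
if __name__ == "__main__":
    service = SummariserService()

    sample_text = (
        "Artificial intelligence systems are increasingly used to condense long "
        "documents into short summaries, saving readers time while aiming to "
        "preserve the key points of the original text."
    )

    # Deterministic beam-search summary
    print(service.summarise(sample_text, max_length=60, min_length=20))

    # Sampled summary with a higher temperature for more varied wording
    print(service.summarise(sample_text, max_length=60, min_length=20,
                            do_sample=True, temperature=1.2))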