File size: 3,737 Bytes
d83d604
98ac441
d1155a6
5813702
ff5002a
 
d1155a6
53d5734
ff5002a
d1155a6
73fdc8f
53d5734
5813702
861709c
daec533
861709c
 
9388e0e
5813702
73fdc8f
5813702
0c65a3b
53d5734
98ac441
53d5734
 
 
 
 
 
 
f13147a
ff5002a
53d5734
 
 
 
 
 
 
 
98ac441
53d5734
5813702
 
 
d1155a6
 
 
 
 
53d5734
d1155a6
53d5734
 
d1155a6
 
53d5734
5813702
d83d604
53d5734
6ed741d
5813702
 
 
98ac441
5813702
 
 
6ed741d
5813702
d83d604
d1155a6
5813702
d1155a6
6ed741d
26a1a9d
 
 
 
 
5813702
d1155a6
26a1a9d
 
5813702
 
 
26a1a9d
6ed741d
d1155a6
5813702
 
 
 
 
 
53d5734
 
fd8e8ce
 
53d5734
 
fd8e8ce
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from nltk.corpus import stopwords
import nltk

# Build the English stopword set used by clean_text().
# The corpus is loaded first and only downloaded when NLTK cannot find it
# locally, so a start-up with the data already installed does not call
# nltk.download() at all.
try:
    stop_words = set(stopwords.words('english'))
except LookupError:
    # NLTK signals a missing corpus with LookupError; fetch it quietly and retry.
    nltk.download('stopwords', quiet=True)
    stop_words = set(stopwords.words('english'))

# Lightweight summarization checkpoints offered in the UI, as
# (display title, Hugging Face checkpoint id) pairs.
_SUMMARIZER_CHECKPOINTS = (
    ("DistilBART CNN", "sshleifer/distilbart-cnn-12-6"),
    ("T5 Small", "t5-small"),
    ("T5 Base", "t5-base"),
    ("Flan-T5 Base", "google/flan-t5-base"),
    ("DistilBART XSum", "sshleifer/distilbart-xsum-12-6"),
)

# Maps the dropdown label "<title> (<checkpoint id>)" to the checkpoint id.
model_choices = {
    f"{title} ({checkpoint})": checkpoint
    for title, checkpoint in _SUMMARIZER_CHECKPOINTS
}

# Filled lazily by load_model(): checkpoint id -> (tokenizer, model).
model_cache = {}

# Clean input text (remove stopwords and SKUs/product codes)
# Both patterns are compiled once at import time instead of on every call.
# A SKU is two or more ASCII letters followed by three or more digits,
# standing alone as a word (e.g. ST1642, AB1234).
_SKU_PATTERN = re.compile(r'\b[A-Za-z]{2,}[0-9]{3,}\b')
# Separators: anything that is neither a letter/digit nor whitespace. \w is
# Unicode-aware for str patterns, so accented letters survive; "_" is listed
# explicitly because \w would otherwise keep it.
_SEPARATOR_PATTERN = re.compile(r'[^\w\s]|_')


def clean_text(input_text, *, stopword_set=None):
    """Strip product codes, punctuation and stopwords from *input_text*.

    Steps, in order:
      1. remove SKU-like tokens (see ``_SKU_PATTERN``);
      2. replace punctuation/symbols (and underscores) with a space;
      3. split on whitespace and drop words whose lowercase form is a stopword;
      4. re-join the remaining words with single spaces.

    Args:
        input_text: Raw text to clean. ``None`` or an empty string yields "".
        stopword_set: Optional collection of lowercase words to drop. When
            omitted, the module-level ``stop_words`` set is used.

    Returns:
        The cleaned text, with no leading or trailing whitespace.
    """
    if not input_text:
        return ""

    if stopword_set is None:
        # Resolved at call time so the module-level set remains the default.
        stopword_set = stop_words

    without_skus = _SKU_PATTERN.sub('', input_text)
    spaced = _SEPARATOR_PATTERN.sub(' ', without_skus)

    kept_words = [
        word for word in spaced.split() if word.lower() not in stopword_set
    ]

    # Joining split() output never produces edge whitespace, so no strip needed.
    return " ".join(kept_words)

# Load model and tokenizer
def load_model(model_name):
    """Return the ``(tokenizer, model)`` pair for *model_name*, loading it once.

    The first request for a checkpoint loads its tokenizer and seq2seq model,
    moves the model to CUDA when available (float16 weights) or keeps it on
    CPU (float32 weights), runs one short generation as a warm-up, and stores
    the pair in ``model_cache``. Later requests return the stored pair.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Fast path: this checkpoint has already been loaded.
    if model_name in model_cache:
        return model_cache[model_name]

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    on_gpu = torch.cuda.is_available()
    weights_dtype = torch.float16 if on_gpu else torch.float32
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=weights_dtype,
    )

    target_device = torch.device("cuda" if on_gpu else "cpu")
    model.to(target_device)

    # Warm-up: one tiny generation so the first real request is not the one
    # paying for the initial forward pass.
    warmup_encoding = tokenizer("summarize: warm up", return_tensors="pt")
    warmup_ids = warmup_encoding.input_ids.to(target_device)
    model.generate(warmup_ids, max_length=10)

    loaded_pair = (tokenizer, model)
    model_cache[model_name] = loaded_pair
    return loaded_pair

# Summarize the text using a selected model
def summarize_text(input_text, model_label, char_limit):
    """Summarize *input_text* with the chosen model and cap the result length.

    The text is cleaned with ``clean_text`` first. Checkpoints whose name
    contains "t5" or "flan" get the "summarize: " task prefix. Generation is
    greedy (``do_sample=False``) with fixed length bounds.

    Args:
        input_text: Text to summarize. ``None``, empty or whitespace-only
            input returns a prompt message instead of calling the model.
        model_label: A key of ``model_choices`` (the dropdown label).
        char_limit: Maximum number of characters to return; coerced to a
            non-negative int so a float from the UI slider is accepted.

    Returns:
        The decoded summary cut to ``char_limit`` characters and stripped, or
        a short message when there is nothing to summarize.

    Raises:
        KeyError: If *model_label* is not a key of ``model_choices``.
    """
    # Guard against None as well as blank text before touching any model.
    if not input_text or not input_text.strip():
        return "Please enter some text."

    input_text = clean_text(input_text)
    if not input_text:
        # Everything was a stopword, product code or punctuation; running the
        # model on an empty prompt would only produce made-up output.
        return "Nothing left to summarize after removing stop words and product codes."

    model_name = model_choices[model_label]
    tokenizer, model = load_model(model_name)

    lowered_name = model_name.lower()
    if "t5" in lowered_name or "flan" in lowered_name:
        input_text = "summarize: " + input_text

    # Move the whole encoding to the model's device so the attention mask
    # travels with the input ids.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(model.device)

    # Generation length bounds; min_len must stay below max_len.
    max_len = 30
    min_len = 5

    # Generate summary
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_len,
        min_length=min_len,
        do_sample=False,
    )

    # Decode the summary
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # A float or negative limit would break or distort the slice below.
    limit = max(0, int(char_limit))
    return summary[:limit].strip()

# Gradio UI
# Input components, created in the order summarize_text() receives them:
# source text, model label, character limit.
_source_box = gr.Textbox(lines=6, label="Enter text to summarize")
_model_picker = gr.Dropdown(
    choices=list(model_choices),
    label="Choose summarization model",
    value="DistilBART CNN (sshleifer/distilbart-cnn-12-6)",
)
_limit_slider = gr.Slider(
    minimum=30,
    maximum=200,
    value=80,
    step=1,
    label="Max Character Limit",
)

# Output component showing the (possibly truncated) summary.
_summary_box = gr.Textbox(lines=3, label="Summary (truncated to character limit)")

iface = gr.Interface(
    fn=summarize_text,
    inputs=[_source_box, _model_picker, _limit_slider],
    outputs=_summary_box,
    title="🚀 Fast Lightweight Summarizer (GPU Optimized)",
    description="Summarize text quickly using compact models ideal for low-latency and ZeroGPU Spaces.",
)

# Start the web app as soon as the module is executed.
iface.launch()