madankn79 commited on
Commit
5813702
·
1 Parent(s): 5efa5bb
Files changed (1) hide show
  1. app.py +48 -12
app.py CHANGED
@@ -1,19 +1,55 @@
1
  import gradio as gr
2
- from transformers import pipeline
3
 
 
 
 
 
 
 
4
 
5
- summarizer = pipeline("summarization", model="google/pegasus-xsum")
 
6
 
7
- def process(text):
8
- summary = summarizer(text, max_length=10, min_length=5, do_sample=False, clean_up_tokenization_spaces=True, truncation=True)
9
- summary = summary[:65]
10
- return summary
 
 
11
 
12
- demo = gr.Interface(
13
- fn=process,
14
- inputs=gr.Textbox(label="Input Text"),
15
- outputs="json"
16
- )
 
 
 
 
 
17
 
18
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
+ # Supported summarization models
5
+ model_choices = {
6
+ "Pegasus (google/pegasus-xsum)": "google/pegasus-xsum",
7
+ "BART (facebook/bart-large-cnn)": "facebook/bart-large-cnn",
8
+ "T5 (t5-small)": "t5-small"
9
+ }
10
 
11
+ # Cache for loaded models/tokenizers
12
+ model_cache = {}
13
 
14
+ def load_model(model_name):
15
+ if model_name not in model_cache:
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
+ model_cache[model_name] = (tokenizer, model)
19
+ return model_cache[model_name]
20
 
21
+ # Summarization function
22
+ def summarize_text(input_text, model_label):
23
+ if not input_text.strip():
24
+ return "Please enter some text."
25
+
26
+ model_name = model_choices[model_label]
27
+ tokenizer, model = load_model(model_name)
28
+
29
+ if "t5" in model_name.lower():
30
+ input_text = "summarize: " + input_text
31
 
32
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
33
+ summary_ids = model.generate(
34
+ inputs["input_ids"],
35
+ max_length=20, # Approximate for 65 characters
36
+ min_length=5,
37
+ do_sample=False
38
+ )
39
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
40
+
41
+ return summary[:65] # Ensure character limit
42
+
43
+ # Gradio UI
44
+ iface = gr.Interface(
45
+ fn=summarize_text,
46
+ inputs=[
47
+ gr.Textbox(lines=6, label="Enter text to summarize"),
48
+ gr.Dropdown(choices=list(model_choices.keys()), label="Choose summarization model", value="Pegasus (google/pegasus-xsum)")
49
+ ],
50
+ outputs=gr.Textbox(lines=2, label="Summary (max 65 characters)"),
51
+ title="Short Text Summarizer",
52
+ description="Summarizes input text to under 65 characters using a selected model."
53
+ )
54
 
55
+ iface.launch()