from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, TrainingArguments
from youtube_transcript_api import YouTubeTranscriptApi
from deepmultilingualpunctuation import PunctuationModel
from googletrans import Translator
import time
import torch
import re
# import httpcore
# setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
cp_aug = 'minnehwg/finetune-newwiki-summarization-ver-augmented2'
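# Load the base VietAI/vit5-base tokenizer and the fine-tuned summarization checkpoint.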
def load_model(cp):
    tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained(cp)
    return tokenizer, model

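# Generate an abstractive summary for one chunk of text using beam search.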
def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
    model.to(device)
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True, padding=True).to(device)
    with torch.no_grad():
        summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

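# Basic normalisation: drop newlines and lowercase the text.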
def processed(text):
    processed_text = text.replace('\n', ' ')
    processed_text = processed_text.lower()
    return processed_text

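# Fetch the English transcript of a YouTube video and join it into a single string.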
def get_subtitles(video_url):
    try:
        video_id = video_url.split("v=")[1]
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
        subs = " ".join(entry['text'] for entry in transcript)
        return transcript, subs
    except Exception as e:
        return [], f"An error occurred: {e}"

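# Restore punctuation in the raw, unpunctuated transcript.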
def restore_punctuation(text):
    model = PunctuationModel()
    result = model.restore_punctuation(text)
    return result

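# Translate long text to Vietnamese by splitting it into sentence chunks that stay
# under the googletrans request limit, translating each chunk, and re-joining the results.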
def translate_long(text, language='vi'):
    translator = Translator()
    limit = 4700
    chunks = []
    current_chunk = ''
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= limit:
            current_chunk += sentence.strip() + ' '
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence.strip() + ' '
    if current_chunk:
        chunks.append(current_chunk.strip())
    translated_text = ''
    for chunk in chunks:
        try:
            time.sleep(1)
            translation = translator.translate(chunk, dest=language)
            translated_text += translation.text + ' '
        except Exception as e:
            translated_text += chunk + ' '
    return translated_text.strip()

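# Split text into chunks of at most max_words words, carrying the last
# overlap_sentences sentences over into the next chunk for context.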
def split_into_chunks(text, max_words=800, overlap_sentences=2):
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    chunks = []
    current_chunk = []
    current_word_count = 0
    for sentence in sentences:
        word_count = len(sentence.split())
        if current_word_count + word_count <= max_words:
            current_chunk.append(sentence)
            current_word_count += word_count
        else:
            if len(current_chunk) >= overlap_sentences:
                overlap = current_chunk[-overlap_sentences:]
                print(f"Overlapping sentences: {' '.join(overlap)}")
            chunks.append(' '.join(current_chunk))
            current_chunk = current_chunk[-overlap_sentences:] + [sentence]
            current_word_count = sum(len(sent.split()) for sent in current_chunk)
    if current_chunk:
        if len(current_chunk) >= overlap_sentences:
            overlap = current_chunk[-overlap_sentences:]
            print(f"Overlapping sentences: {' '.join(overlap)}")
        chunks.append(' '.join(current_chunk))
    return chunks

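# Capitalize the first letter of every sentence in the generated summary.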
def post_processing(text):
    sentences = re.split(r'(?<=[.!?])\s*', text)
    for i in range(len(sentences)):
        if sentences[i]:
            sentences[i] = sentences[i][0].upper() + sentences[i][1:]
    text = " ".join(sentences)
    return text

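# De-duplicate sentences and format the summary as bullet points.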
def display(text):
    sentences = re.split(r'(?<=[.!?])\s*', text)
    unique_sentences = list(dict.fromkeys(sentences[:-1]))
    formatted_sentences = [f"• {sentence}" for sentence in unique_sentences]
    return formatted_sentences

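# End-to-end pipeline: fetch English subtitles, restore punctuation, translate to
# Vietnamese, split into overlapping chunks, summarize each chunk with the fine-tuned
# ViT5 model, then post-process and format the result as bullet points.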
def pipeline(url):
    trans, sub = get_subtitles(url)
    sub = restore_punctuation(sub)
    vie_sub = translate_long(sub)
    vie_sub = processed(vie_sub)
    chunks = split_into_chunks(vie_sub, 700, 3)
    sum_para = []
    for chunk in chunks:
        tmp = summarize(chunk, model, tokenizer, num_beams=4)
        sum_para.append(tmp)
    summary_text = ' '.join(sum_para)
    del sub, vie_sub, sum_para, chunks
    summary_text = post_processing(summary_text)
    result = display(summary_text)
    return result

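# Gradio app: load the model once at startup, then expose the pipeline through a simple web UI.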
import gradio as gr

cp_aug = 'minnehwg/finetune-newwiki-summarization-ver-augmented2'

def get_model():
    checkpoint = cp_aug
    tokenizer, model = load_model(checkpoint)
    return tokenizer, model

tokenizer, model = get_model()

def generate_summary(input_text):
    # Join the bullet list into a single string so it renders cleanly in the output Textbox.
    return "\n".join(pipeline(input_text))

demo = gr.Interface(
    fn=generate_summary,
    inputs=gr.Textbox(lines=2, placeholder="Enter your URL..."),
    outputs=gr.Textbox(label="Generated Text"),
    title="Chào mừng đến với hệ thống tóm tắt của Minne >.<",  # "Welcome to Minne's summarization system >.<"
    description="Enter the URL to summarize and click 'Submit' to generate the summary."
)

demo.launch()