Relax-Teacher / app.py
Rimi98's picture
Update app.py
fcb35fa
raw
history blame
2.39 kB
import os
import re
import subprocess

import gradio as gr
import onnxruntime
import torch
from transformers import AutoTokenizer
from transformers import pipeline
# --- Topic classifier: tokenizer plus quantized ONNX model --- #
token = AutoTokenizer.from_pretrained('distilroberta-base')
inf_session = onnxruntime.InferenceSession('classifier1-quantized.onnx')
# Only the first declared input and the first declared output of the ONNX
# graph are used when running inference.
input_name = inf_session.get_inputs()[0].name
output_name = inf_session.get_outputs()[0].name
# Label names, zipped positionally with the classifier's output scores.
classes = [
    'Art',
    'Astrology',
    'Biology',
    'Chemistry',
    'Economics',
    'History',
    'Literature',
    'Philosophy',
    'Physics',
    'Politics',
    'Psychology',
    'Sociology',
]

# --- Audio/Video to text --- #
# Run the pipelines on the first CUDA device when one is available.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base.en",
    chunk_length_s=30,
    device=device,
)

# --- Text summary --- #
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    device=device,
)
def video_identity(video):
    """Transcribe a media file and return the recognized text.

    Args:
        video: Value forwarded unchanged to the module-level
            speech-recognition pipeline ``pipe``.

    Returns:
        The ``"text"`` entry of the pipeline's result.
    """
    result = pipe(video)
    return result["text"]
def _chunk_sentences(text, max_words):
    """Pack the sentences of *text* into strings of at most *max_words* words."""
    # Split after sentence-ending punctuation that is followed by whitespace,
    # so the punctuation is kept and decimals such as "3.14" stay intact.
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    chunks = []
    current = []
    for sentence in sentences:
        words = sentence.split()
        if not words:
            continue
        # A sentence longer than the limit can never fit into a chunk:
        # flush what has been collected and hard-split the sentence.
        while len(words) > max_words:
            if current:
                chunks.append(' '.join(current))
                current = []
            chunks.append(' '.join(words[:max_words]))
            words = words[max_words:]
        if current and len(current) + len(words) > max_words:
            chunks.append(' '.join(current))
            current = []
        current.extend(words)
    if current:
        chunks.append(' '.join(current))
    return chunks


def summary(text, max_chunk=500, max_length=100):
    """Summarize *text* chunk by chunk with the module-level ``summarizer``.

    The text is split into sentences, which are packed into chunks of at
    most ``max_chunk`` whitespace-separated words; every chunk is then
    summarized independently.

    Args:
        text: The text to summarize.
        max_chunk: Maximum number of words per chunk handed to the
            summarizer. Must be at least 1.
        max_length: ``max_length`` forwarded to the summarizer for each
            chunk.

    Returns:
        A list with one dict per chunk, each holding a ``'summary_text'``
        entry. For empty or whitespace-only input the model is not called
        and a single entry with an empty summary is returned, so callers
        that index the first element keep working.

    Raises:
        ValueError: If ``max_chunk`` is smaller than 1.
    """
    if max_chunk < 1:
        raise ValueError('max_chunk must be at least 1')
    chunks = _chunk_sentences(text, max_chunk)
    if not chunks:
        return [{'summary_text': ''}]
    # truncation=True is a safety net for chunks whose token count still
    # exceeds the model's input window despite the word limit.
    return summarizer(chunks, max_length=max_length, truncation=True)
def classify(vid):
    """Transcribe, summarize and topic-classify an uploaded video.

    The audio track is extracted with ffmpeg into a ``.wav`` file next to
    the upload, transcribed with ``video_identity``, summarized with
    ``summary`` and the summary is scored by the ONNX classifier.

    Args:
        vid: Filesystem path of the uploaded video.

    Returns:
        A tuple ``(full_text, summary_text, scores)`` where ``scores`` maps
        every name in ``classes`` to a float obtained by applying a sigmoid
        to the classifier's output.

    Raises:
        ValueError: If no video path was supplied.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero
            status.
    """
    if not vid:
        raise ValueError('No video file was provided')

    # splitext copes with extensions of any length (".mp4", ".webm", ...).
    wav_path = f'{os.path.splitext(vid)[0]}.wav'
    if wav_path != vid:
        # "-i <input> <output>": read the upload and write the .wav file.
        # "-y" overwrites a stale output instead of blocking on a prompt.
        subprocess.run(['ffmpeg', '-y', '-i', vid, wav_path], check=True)

    full_text = video_identity(wav_path)

    # summary() yields one entry per text chunk; use all of them so long
    # recordings are not reduced to the summary of their first chunk.
    summary_text = ' '.join(
        part['summary_text'].strip() for part in summary(full_text)
    )

    # Let the tokenizer truncate so the closing special token is kept.
    input_ids = token(summary_text, truncation=True, max_length=512)['input_ids']
    logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
    # A single sequence was fed in, so row 0 holds its per-class scores.
    probs = torch.sigmoid(torch.FloatTensor(logits))[0]
    return full_text, summary_text, dict(zip(classes, map(float, probs)))
# Output widgets for the transcript and its summary.
text1 = gr.Textbox(label="Text")
text2 = gr.Textbox(label="Summary")
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces - confirm they exist in the Gradio version this app is pinned to.
video_input = gr.inputs.Video(source="upload", type="filepath")
topic_label = gr.outputs.Label(num_top_classes=3)
# Wire the upload through classify(); its three return values feed the
# three output widgets in order.
iface = gr.Interface(
    fn=classify,
    inputs=video_input,
    outputs=[text1, text2, topic_label],
)
iface.launch(inline=False)