Spaces:
Runtime error
Runtime error
import torch | |
import gradio as gr | |
from src.load_html import get_description_html | |
from src.audio_processor import AudioProcessor | |
from src.model.behaviour_model import get_behaviour_model | |
from transformers import ( | |
pipeline, | |
WavLMForSequenceClassification | |
) | |
def create_demo():
    """Build (but do not launch) the Gradio Blocks interface for the demo.

    Loads the models handed to ``AudioProcessor`` onto ``"cuda"`` when
    ``torch.cuda.is_available()`` reports a GPU, otherwise onto ``"cpu"``:

    * a Whisper automatic-speech-recognition pipeline
      (``openai/whisper-large-v3-turbo``), passed as ``segmentation_model``;
    * a WavLM sequence classifier
      (``links-ads/kk-speech-emotion-recognition``), passed as
      ``emotion_model`` in eval mode;
    * the project's behaviour model, built from
      ``src/model/classifier_weights.bin``.

    The UI contains a description header, an audio upload (delivered to the
    handler as a file path), a button, and a plot output. Clicking the button
    calls the ``AudioProcessor`` instance on the uploaded audio.

    Returns:
        The assembled ``gr.Blocks`` app.
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    whisper_id = "openai/whisper-large-v3-turbo"
    # Despite the keyword it is passed under below, this is an ASR pipeline.
    # NOTE(review): presumably AudioProcessor uses its (timestamped)
    # transcription output to segment the audio; confirm there.
    asr_pipeline = pipeline(
        task="automatic-speech-recognition",
        model=whisper_id,
        tokenizer=whisper_id,
        device=device,
    )

    # nn.Module.to() and .eval() both return the module itself, so they chain.
    wavlm_classifier = WavLMForSequenceClassification.from_pretrained(
        "links-ads/kk-speech-emotion-recognition"
    ).to(device).eval()

    behaviour = get_behaviour_model(
        classifier_weights_path="src/model/classifier_weights.bin",
        device=device,
    )

    processor = AudioProcessor(
        emotion_model=wavlm_classifier,
        segmentation_model=asr_pipeline,
        device=device,
        behaviour_model=behaviour,
    )

    with gr.Blocks() as demo:
        # The function object itself (not its return value) is the HTML value.
        # NOTE(review): Gradio allows callable component values, which it
        # evaluates on page load; confirm this is intended rather than a
        # missing call.
        gr.HTML(get_description_html)
        upload = gr.Audio(label="Upload Audio", type="filepath")
        generate = gr.Button("Generate Graph")
        plot = gr.Plot(label="Generated Graph")
        # The AudioProcessor instance is the handler, so it is called with the
        # uploaded file path. Presumably it returns a figure for gr.Plot;
        # confirm in src/audio_processor.py.
        generate.click(fn=processor, inputs=upload, outputs=plot)
    return demo
if __name__ == "__main__":
    # Script entry point: build the app, keep it bound to the module-level
    # name ``demo``, then serve it with the ``show_api`` option turned off.
    launch_options = {"show_api": False}
    demo = create_demo()
    demo.launch(**launch_options)