# roest-demo / app.py
# Last commit 243b3bb by saattrupdan: "fix: Force HUGGINGFACE_HUB_TOKEN to be present" (1.69 kB)
"""Røst ASR demo."""
import os
import warnings
import gradio as gr
import numpy as np
import samplerate
import torch
from punctfix import PunctFixer
from transformers import pipeline
# Silence FutureWarning noise emitted by the model libraries at load time.
warnings.filterwarnings(action="ignore", category=FutureWarning)

TITLE = "Røst ASR Demo"
DESCRIPTION = """
This is a demo of the Danish speech recognition model
[Røst](https://huggingface.co/alexandrainst/roest-315m). Speak into the microphone and
see the text appear on the screen!
"""

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Speech-to-text model. Indexing os.environ (rather than .get) makes the app
# fail immediately with a KeyError if HUGGINGFACE_HUB_TOKEN is not set.
transcriber = pipeline(
    "automatic-speech-recognition",
    model="alexandrainst/roest-315m",
    device=device,
    token=os.environ["HUGGINGFACE_HUB_TOKEN"],
)

# Danish punctuation/casing restorer applied to the raw model output.
transcription_fixer = PunctFixer(language="da", device=device)
def transcribe_audio(sampling_rate_and_audio: tuple[int, np.ndarray]) -> str:
    """Transcribe the audio.

    Args:
        sampling_rate_and_audio:
            A tuple with the sampling rate and the audio, as produced by a
            `gr.Audio` component with its default numpy output. Gradio passes
            None instead when the user submits without any audio.

    Returns:
        The punctuated transcription, or an empty string if there was no audio
        or nothing was recognised.
    """
    # Sampling rate the audio is resampled to before being fed to the model.
    target_sampling_rate = 16_000

    # Guard against "submit" with no recording/upload and against empty clips,
    # both of which would otherwise crash unpacking or resampling.
    if sampling_rate_and_audio is None:
        return ""
    sampling_rate, audio = sampling_rate_and_audio
    if audio.size == 0:
        return ""

    # Gradio delivers integer PCM (typically int16). Scale signed integer audio
    # to floats in [-1, 1], the range raw-audio ASR pipelines expect; any other
    # dtype is only cast to float32 for the resampler.
    if np.issubdtype(audio.dtype, np.signedinteger):
        audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
    else:
        audio = audio.astype(np.float32)

    # Multi-channel audio is laid out as (samples, channels); average to mono.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Only resample when needed; a ratio-1.0 sinc pass is wasted work.
    if sampling_rate != target_sampling_rate:
        audio = samplerate.resample(
            audio, target_sampling_rate / sampling_rate, "sinc_best"
        )

    transcription = transcriber(inputs=audio)
    if not isinstance(transcription, dict):
        return ""

    text = transcription["text"]
    # Nothing recognised: skip punctuation restoration entirely.
    if not text.strip():
        return ""
    return transcription_fixer.punctuate(text=text)
# Web UI: audio from the microphone or an uploaded file goes in, and the
# punctuated transcription comes out in a textbox.
demo = gr.Interface(
    title=TITLE,
    description=DESCRIPTION,
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["microphone", "upload"]),
    outputs="textbox",
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode` — confirm against the pinned Gradio version.
    allow_flagging="never",
)
demo.launch()