ASR / app.py
NightPrince's picture
Update app.py
98c9824 verified
raw
history blame
1.07 kB
import gradio as gr
import numpy as np
import torch
from transformers import Speech2Text2Processor, SpeechEncoderDecoderModel
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Load the pre-trained model and processor once, at startup.
# NOTE: per its model card, "facebook/s2t-wav2vec2-large-en-ar" is a speech
# encoder-decoder checkpoint (wav2vec2 encoder + autoregressive text decoder)
# that translates English speech into Arabic text. It has no CTC head, so it is
# loaded with SpeechEncoderDecoderModel / Speech2Text2Processor and decoded with
# `generate`, not with Wav2Vec2ForCTC + argmax over logits.
model_name = "facebook/s2t-wav2vec2-large-en-ar"
model = SpeechEncoderDecoderModel.from_pretrained(model_name)
processor = Speech2Text2Processor.from_pretrained(model_name)


def _prepare_audio(audio, target_rate=16000):
    """Convert a Gradio audio value into a mono float32 waveform at *target_rate*.

    Args:
        audio: ``None`` (nothing recorded/uploaded), a ``(sample_rate, data)``
            tuple as produced by ``gr.Audio(type="numpy")``, or a bare array
            that is assumed to already be sampled at *target_rate*.
        target_rate: Sampling rate in Hz expected by the feature extractor.

    Returns:
        A 1-D ``np.float32`` array; empty when there is no audio.

    Raises:
        ValueError: if the reported sample rate is not positive.
    """
    if audio is None:
        return np.zeros(0, dtype=np.float32)

    if isinstance(audio, tuple):
        sample_rate, samples = audio
    else:
        sample_rate, samples = target_rate, audio
    sample_rate = int(sample_rate)
    if sample_rate <= 0:
        raise ValueError(f"invalid sample rate: {sample_rate}")
    samples = np.asarray(samples)

    # Integer PCM (Gradio typically yields int16) is mapped onto [-1.0, 1.0];
    # unsigned formats are centred first. This must happen before down-mixing,
    # because averaging would silently turn the ints into floats.
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        midpoint = (int(info.max) + int(info.min) + 1) / 2.0
        half_range = (int(info.max) - int(info.min) + 1) / 2.0
        samples = (samples.astype(np.float64) - midpoint) / half_range
    else:
        samples = samples.astype(np.float64)

    # Down-mix instead of flattening: flattening a (samples, channels) array
    # interleaves the channels. Averaging over the shorter axis also copes with
    # a (1, n) row vector.
    if samples.ndim == 2:
        samples = samples.mean(axis=int(np.argmin(samples.shape)))
    samples = samples.reshape(-1)

    # Browser audio is often 44.1/48 kHz. Resample by linear interpolation.
    # This is a cheap approximation with no anti-alias filter, which is
    # acceptable for a demo.
    if sample_rate != target_rate and samples.size:
        n_out = int(round(samples.size * target_rate / sample_rate))
        positions = np.arange(n_out) * (sample_rate / target_rate)
        samples = np.interp(positions, np.arange(samples.size), samples)

    return samples.astype(np.float32)


# Define a function for the speech model
def transcribe(audio):
    """Run the speech model on one Gradio audio value and return the decoded text.

    Args:
        audio: value from ``gr.Audio(type="numpy")`` -- a ``(sample_rate, data)``
            tuple, or ``None`` when nothing was recorded. A bare array is also
            accepted and treated as already being at the model's sampling rate.

    Returns:
        The generated text, or an empty string when there is no audio.
    """
    target_rate = processor.feature_extractor.sampling_rate
    speech = _prepare_audio(audio, target_rate)
    if speech.size == 0:
        return ""

    inputs = processor(speech, sampling_rate=target_rate, return_tensors="pt")
    # `generate` runs under no_grad internally. The attention mask is only
    # present if the feature extractor is configured to return one.
    # NOTE(review): output length is bounded by the checkpoint's generation
    # config; pass max_new_tokens if long clips come back truncated.
    generated_ids = model.generate(
        inputs["input_values"],
        attention_mask=inputs.get("attention_mask"),
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Define the Gradio interface: audio in, text out.
# type="numpy" makes Gradio pass a (sample_rate, samples) tuple to `transcribe`.
interface = gr.Interface(fn=transcribe, inputs=gr.Audio(type="numpy"), outputs="text")

# Launch only when run as a script, so importing this module (for example, from
# tests or another app) does not start a web server as a side effect.
if __name__ == "__main__":
    interface.launch()