# ASRWhisper / app.py
# ashpikachu2k1's picture
# Update app.py
# d21bd17 verified
import torch
import torchaudio
import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Run on GPU when torch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the fine-tuned Odia Whisper checkpoint (Odia speech -> Odia text).
whisper_model_path = "./whisper-odia-final"  # Change if needed
processor = WhisperProcessor.from_pretrained(whisper_model_path)
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_path).to(device)

# Load IndicTrans2 for Odia -> English translation.
# Use the *indic-en* direction: the *en-indic* checkpoints translate English
# into Indic languages, the opposite of what this app needs. The id must not
# carry trailing whitespace, or the Hub lookup fails.
# IndicTrans2 ships its own tokenizer/model code on the Hub, so
# trust_remote_code=True is required to load it.
trans_model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
translator_tokenizer = AutoTokenizer.from_pretrained(
    trans_model_id, trust_remote_code=True, use_fast=False
)
translator_model = AutoModelForSeq2SeqLM.from_pretrained(
    trans_model_id, trust_remote_code=True
).to(device)
# Translation function with language tags
def translate_to_english(text, *, src_lang="ory_Orya", max_length=256):
    """Translate Indic text (Odia by default) into English with IndicTrans2.

    Args:
        text: Source sentence. ``None``, empty, or whitespace-only input
            returns ``""`` without running the model.
        src_lang: IndicTrans2 language code of the source text;
            ``"ory_Orya"`` is Odia written in Odia script.
        max_length: Maximum length, in tokens, of the generated translation.

    Returns:
        The English translation as a plain string.
    """
    if text is None or not text.strip():
        return ""
    # IndicTrans2's tokenizer expects "<src_lang> <tgt_lang> <sentence>".
    # "<2en>" is not an IndicTrans2 language tag, so it must not be used here.
    tagged = f"{src_lang} eng_Latn {text.strip()}"
    # NOTE(review): the official IndicTrans2 pipeline also runs
    # IndicTransToolkit's IndicProcessor (normalization/postprocessing). It is
    # not a dependency of this app, so output quality may differ slightly.
    inputs = translator_tokenizer(
        tagged,
        return_tensors="pt",
        padding=True,
        truncation=True,
        return_attention_mask=True,
    ).to(device)
    output = translator_model.generate(**inputs, max_length=max_length)
    return translator_tokenizer.batch_decode(output, skip_special_tokens=True)[0]
# ASR + Translation Pipeline
def transcribe(audio_path):
    """Transcribe an Odia audio file with Whisper, then translate it to English.

    Args:
        audio_path: Path to the audio file (the Gradio input uses
            ``type="filepath"``), or ``None`` when nothing was provided.

    Returns:
        An ``(odia_text, english_text)`` tuple. When ``audio_path`` is
        ``None`` it returns ``("No audio received.", "")``.
    """
    if audio_path is None:
        return "No audio received.", ""
    target_sr = 16000  # sampling rate passed to the Whisper processor
    # torchaudio.load returns (channels, frames) by default.
    speech, sr = torchaudio.load(audio_path)
    # Downmix to mono. A (channels, frames) array would be treated by the
    # Whisper processor as a batch, transcribing every channel but keeping
    # only channel 0. For mono input this equals the previous squeeze().
    speech = speech.mean(dim=0)
    if sr != target_sr:
        speech = torchaudio.functional.resample(speech, sr, target_sr)
    # NOTE(review): Whisper's feature extractor pads/truncates to a 30 s
    # window, so longer recordings are cut off.
    input_features = processor(
        speech.numpy(), sampling_rate=target_sr, return_tensors="pt"
    ).input_features.to(device)
    predicted_ids = model.generate(input_features)
    odia_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return odia_text, translate_to_english(odia_text)
# Gradio UI
# Gradio UI: one audio input and two text outputs, matching the
# (odia_text, english_text) tuple returned by transcribe().
interface = gr.Interface(
    fn=transcribe,
    # Gradio 4+ replaced `source=` with the list-valued `sources=`; enabling
    # "upload" as well makes the component match its "Record or Upload" label.
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",
        label="🎀 Record or Upload Odia Audio",
    ),
    outputs=[
        # Labels were mojibake (UTF-8 emoji/arrows mis-decoded); restored here.
        gr.Textbox(label="📝 Odia Transcription"),
        gr.Textbox(label="🌐 English Translation"),
    ],
    title="Whisper Odia ASR + Translation",
    description="🎙️ Speak in Odia → Get Odia transcription → Get English translation using IndicTrans2",
)

# Only start the server when run as a script (e.g. `python app.py` on Spaces),
# not when the module is imported.
if __name__ == "__main__":
    interface.launch()