Spaces:
Runtime error
Runtime error
import torch | |
import torchaudio | |
import gradio as gr | |
from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fine-tuned Whisper checkpoint for Odia, loaded from the relative path below.
whisper_model_path = "./whisper-odia-final"  # Change if needed
processor = WhisperProcessor.from_pretrained(whisper_model_path)
model = WhisperForConditionalGeneration.from_pretrained(whisper_model_path)
model = model.to(device)
# Load the IndicTrans2 Indic -> English checkpoint.
# translate_to_english() turns Odia into English, so the "indic-en" variant
# is needed; the "en-indic" variant translates in the opposite direction.
# The repo id must not contain stray whitespace or the Hub rejects it.
trans_model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
# IndicTrans2 publishes its tokenizer/model classes as custom code on the Hub,
# so both loaders have to opt in with trust_remote_code=True.
translator_tokenizer = AutoTokenizer.from_pretrained(
    trans_model_id, use_fast=False, trust_remote_code=True
)
translator_model = AutoModelForSeq2SeqLM.from_pretrained(
    trans_model_id, trust_remote_code=True
).to(device)

# Language tags placed in front of every sentence handed to IndicTrans2.
SRC_LANG_TAG = "ory_Orya"  # Odia
TGT_LANG_TAG = "eng_Latn"  # English


# Translation function with language tags
def translate_to_english(text):
    """Translate Odia text to English with the IndicTrans2 model.

    Args:
        text: Odia text to translate. ``None`` is treated like an empty string.

    Returns:
        The decoded English translation, or ``""`` when ``text`` is empty or
        contains only whitespace.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        return ""

    # IndicTrans2 expects "<source tag> <target tag> <sentence>".
    # NOTE(review): the model card also runs IndicTransToolkit's IndicProcessor
    # for script normalisation and post-processing; it is not imported by this
    # file, so only the tag prefix is reproduced here -- confirm output quality.
    tagged = f"{SRC_LANG_TAG} {TGT_LANG_TAG} {cleaned}"

    inputs = translator_tokenizer(tagged, return_tensors="pt", padding=True).to(device)
    output_ids = translator_model.generate(**inputs, max_length=256)
    return translator_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
# ASR + Translation Pipeline
TARGET_SAMPLE_RATE = 16000  # Hz; the rate passed to the Whisper processor below


def _load_mono_waveform(audio_path):
    """Load ``audio_path`` as a 1-D mono waveform at TARGET_SAMPLE_RATE."""
    waveform, sample_rate = torchaudio.load(audio_path)
    # torchaudio returns (channels, frames). Averaging over the channel axis
    # yields one 1-D signal for mono and stereo alike; a plain squeeze() would
    # leave a stereo recording two-dimensional.
    waveform = waveform.mean(dim=0)
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, TARGET_SAMPLE_RATE
        )
    return waveform


def transcribe(audio_path):
    """Transcribe an Odia recording and translate the transcript to English.

    Args:
        audio_path: Path to the recorded/uploaded audio file, or ``None`` when
            the UI submitted nothing.

    Returns:
        A ``(odia_text, english_text)`` tuple. When no audio was supplied the
        tuple is ``("No audio received.", "")``.
    """
    if not audio_path:
        return "No audio received.", ""

    waveform = _load_mono_waveform(audio_path)
    features = processor(
        waveform.numpy(), sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt"
    ).input_features.to(device)

    predicted_ids = model.generate(features)
    odia_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    english_text = translate_to_english(odia_text)
    return odia_text, english_text
# Gradio UI
import inspect  # stdlib; only needed to probe the installed Gradio API below

# Gradio 4 replaced gr.Audio's `source` keyword (a single string) with
# `sources` (a list) and rejects the old spelling. Pick whichever keyword the
# installed version declares so the app starts on either API.
if "sources" in inspect.signature(gr.Audio.__init__).parameters:
    # The label promises "Record or Upload", so offer both; microphone first.
    _audio_source_kwargs = {"sources": ["microphone", "upload"]}
else:
    _audio_source_kwargs = {"source": "microphone"}

# NOTE(review): the leading glyphs in the labels below look like mis-encoded
# emoji; they are kept exactly as found -- confirm against the original file.
interface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(
        type="filepath",
        label="π€ Record or Upload Odia Audio",
        **_audio_source_kwargs,
    ),
    outputs=[
        gr.Textbox(label="π Odia Transcription"),
        gr.Textbox(label="π English Translation"),
    ],
    title="Whisper Odia ASR + Translation",
    description="ποΈ Speak in Odia β Get Odia transcription β Get English translation using IndicTrans2",
)

# Start the web server for the app.
interface.launch()