ntviet's picture
Update app.py
e40922c verified
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import gradio as gr
import torch
import librosa
import numpy as np
# Cấu hình thiết bị
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Tải model và processor
model = WhisperForConditionalGeneration.from_pretrained("ntviet/whisper-small-hre5.2").to(device)
processor = WhisperProcessor.from_pretrained("ntviet/whisper-small-hre5.2")
# Tắt tính năng forced_decoder_ids trong generation config
if hasattr(model.generation_config, "forced_decoder_ids"):
model.generation_config.forced_decoder_ids = None
def transcribe(audio_path):
try:
# Đọc file âm thanh với librosa
audio, sr = librosa.load(audio_path, sr=16000)
# Chuẩn hóa dữ liệu âm thanh
audio = librosa.util.normalize(audio) * 0.9 # Giảm volume để tránh clipping
# Xử lý âm thanh
inputs = processor(
audio,
sampling_rate=16000,
return_tensors="pt"
).to(device)
# Generate với cấu hình tùy chỉnh
outputs = model.generate(
inputs.input_features,
generation_config=model.generation_config
)
# Giải mã kết quả
text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
return text
except Exception as e:
return f"Lỗi khi xử lý audio: {str(e)}"
# # Giao diện Gradio
# with gr.Blocks() as demo:
# gr.Markdown("""
# # Nhận dạng giọng nói tiếng Việt
# Model: whisper-small-hre5.2 (đã fine-tune)
# """)
# with gr.Row():
# audio_input = gr.Audio(
# sources=["upload", "microphone"],
# type="filepath",
# label="Tải lên hoặc ghi âm"
# )
# output_text = gr.Textbox(label="Kết quả nhận dạng")
# submit_btn = gr.Button("Bắt đầu nhận dạng")
# submit_btn.click(
# fn=transcribe,
# inputs=audio_input,
# outputs=output_text,
# api_name="/predict"
# )
# demo.launch(debug=True, show_error=True)
# Giao diện Gradio cho phép ghi âm hoặc upload
demo = gr.Interface(
fn=transcribe,
inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="🎙️ Upload hoặc ghi âm (.mp3)"),
outputs=gr.Textbox(label="📝 Kết quả chuyển văn bản"),
title="Whisper HRE Demo",
description="🔤 Nhận dạng tiếng H'Re bằng mô hình Whisper fine-tuned.",
theme="soft",
api_name="/predict" # <-- Đảm bảo endpoint cho gradio_client
)
# Chạy app với hiển thị lỗi đầy đủ
demo.launch(show_error=True)