# gradio_app.py
import gradio as gr
import os
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoModel  # AutoModel, as the ASR model card implies
import numpy as np
import google.generativeai as genai
import librosa
import torchaudio  # models like this often rely on torchaudio for audio loading/processing
# --- Configuration ---
ASR_MODEL_NAME = "ai4bharat/indic-conformer-600m-multilingual"
TARGET_SAMPLE_RATE = 16000  # the ASR model expects 16 kHz input
TTS_MODEL_NAME = "ai4bharat/indic-parler-tts"
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")  # never hard-code a real key as the fallback
GEMINI_MODEL_NAME_GRADIO = "gemini-1.5-flash-latest"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# torch_dtype applies to ParlerTTS and similar models; the ASR model may manage its own precision.
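# A minimal sketch of the dtype selection the comment above alludes to.
# TORCH_DTYPE is an assumption (nothing below uses it yet); transformers'
# from_pretrained accepts a torch_dtype kwarg if you want half precision, e.g.
# ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, torch_dtype=TORCH_DTYPE).
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32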
# --- Global Model Variables ---
asr_model_gradio = None  # AutoModel instance for ASR
gemini_model_instance_gradio = None
tts_model_gradio = None
tts_tokenizer_gradio = None  # ParlerTTS tokenizer
# --- Model Loading & API Configuration --- | |
def load_all_resources_gradio(): | |
global asr_model_gradio, tts_model_gradio, tts_tokenizer_gradio, gemini_model_instance_gradio | |
print(f"Gradio: Loading resources. ASR will be on device: {DEVICE}") | |
if asr_model_gradio is None: | |
print(f"Gradio: Loading ASR model: {ASR_MODEL_NAME} using AutoModel") | |
try: | |
# Load using AutoModel as per the model card's implication | |
asr_model_gradio = AutoModel.from_pretrained(ASR_MODEL_NAME, trust_remote_code=True) | |
asr_model_gradio.to(DEVICE) # Move model to device | |
# The model might handle its own precision (e.g. .half()) internally if `trust_remote_code` allows | |
# Or you might need to call asr_model_gradio.half() if it supports it and you're on CUDA. | |
if DEVICE == "cuda" and hasattr(asr_model_gradio, 'half'): | |
print("Gradio: Applying .half() to ASR model.") | |
asr_model_gradio.half() | |
asr_model_gradio.eval() | |
print(f"Gradio: ASR model ({ASR_MODEL_NAME}) loaded using AutoModel.") | |
except Exception as e: | |
print(f"Gradio: Failed to load ASR model {ASR_MODEL_NAME} using AutoModel: {e}") | |
import traceback | |
traceback.print_exc() | |
asr_model_gradio = None | |
    if tts_model_gradio is None:  # ParlerTTS loading
        print(f"Gradio: Loading IndicParler-TTS model: {TTS_MODEL_NAME}")
        try:
            # ParlerTTS needs its own tokenizer; the ASR model's remote code
            # manages its own tokenizer/processor internally.
            tts_tokenizer_gradio = AutoTokenizer.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True)
            tts_model_gradio = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_NAME, trust_remote_code=True).to(DEVICE)
            print("Gradio: IndicParler-TTS model loaded.")
        except Exception as e:
            print(f"Gradio: Failed to load IndicParler-TTS model {TTS_MODEL_NAME}: {e}")
            tts_model_gradio = None
            tts_tokenizer_gradio = None
    if gemini_model_instance_gradio is None:  # Gemini configuration
        if not GEMINI_API_KEY:
            print("Gradio: GEMINI_API_KEY not found. LLM functionality via Gemini will be limited.")
        else:
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                gemini_model_instance_gradio = genai.GenerativeModel(GEMINI_MODEL_NAME_GRADIO)
                print(f"Gradio: Gemini API configured with model: {GEMINI_MODEL_NAME_GRADIO}")
            except Exception as e:
                print(f"Gradio: Failed to configure Gemini API: {e}")
                gemini_model_instance_gradio = None
    print("Gradio: All resources loaded (or attempted).")
# --- Helper Functions ---
def transcribe_audio_gradio(audio_input_tuple):
    if asr_model_gradio is None:
        return f"Error: ASR model ({ASR_MODEL_NAME}) not loaded."
    if audio_input_tuple is None:
        print("Gradio: No audio provided to transcribe_audio_gradio.")
        return "No audio provided."
    sample_rate, audio_numpy = audio_input_tuple
    if audio_numpy is None or audio_numpy.size == 0:
        print("Gradio: Audio numpy array is empty.")
        return "Empty audio received."
    # Downmix to mono float32, the usual expectation for ASR input
    if audio_numpy.ndim == 2:
        if audio_numpy.shape[0] == 2:  # channels-first stereo, (2, num_samples)
            audio_numpy = librosa.to_mono(audio_numpy)
        elif audio_numpy.shape[1] == 2:  # channels-last stereo, (num_samples, 2)
            audio_numpy = np.mean(audio_numpy, axis=1)
    if audio_numpy.dtype != np.float32:
        if np.issubdtype(audio_numpy.dtype, np.integer):
            # Scale integer PCM into [-1.0, 1.0], e.g. an int16 sample of 16384 becomes ~0.5
            audio_numpy = audio_numpy.astype(np.float32) / np.iinfo(audio_numpy.dtype).max
        else:
            audio_numpy = audio_numpy.astype(np.float32)
    # Resample to TARGET_SAMPLE_RATE (16 kHz)
    if sample_rate != TARGET_SAMPLE_RATE:
        print(f"Gradio: Resampling audio from {sample_rate} Hz to {TARGET_SAMPLE_RATE} Hz.")
        try:
            audio_numpy = librosa.resample(y=audio_numpy, orig_sr=sample_rate, target_sr=TARGET_SAMPLE_RATE)
        except Exception as e:
            print(f"Gradio: Error during resampling: {e}")
            return f"Error during audio resampling: {str(e)}"
    try:
        print(f"Gradio: Preparing to transcribe with {ASR_MODEL_NAME}. Input audio shape: {audio_numpy.shape}")
        # The model card's example, `model(wav, "hi", "ctc")`, takes a waveform
        # tensor (the card loads audio with torchaudio.load), so convert the
        # numpy array to a PyTorch tensor on the right device. Many ASR models
        # expect a 1-D waveform per channel, so squeeze first.
        if audio_numpy.ndim > 1:
            audio_numpy = audio_numpy.squeeze()  # drop singleton dimensions
            if audio_numpy.ndim > 1:
                print(f"Gradio: Audio numpy array for ASR has unexpected dimensions after processing: {audio_numpy.shape}")
                return "Error: Audio processing resulted in unexpected dimensions."
        wav_tensor = torch.from_numpy(audio_numpy).to(DEVICE)
        # The model may expect a batch dimension, e.g. [1, num_samples]
        if wav_tensor.ndim == 1:
            wav_tensor = wav_tensor.unsqueeze(0)
        print(f"Gradio: Transcribing with {ASR_MODEL_NAME} using CTC. Input tensor shape: {wav_tensor.shape}")
        # Decode with CTC ("rnnt" is the alternative, if supported). The
        # language code "hi" (Hindi) follows the model card's example; the card
        # does not document auto-detection, so make this configurable if you
        # need other languages -- the model is trained on 22 of them.
        # The model() call is synchronous; Gradio runs it in a worker thread.
        with torch.no_grad():  # good practice for inference
            transcription_result = asr_model_gradio(wav_tensor, "hi", "ctc")
        # The model card implies the output is the transcribed string itself,
        # but it could be a list (e.g. when batched), so handle both.
        if isinstance(transcription_result, list) and len(transcription_result) > 0:
            transcribed_text = transcription_result[0]  # first result for non-batched input
        elif isinstance(transcription_result, str):
            transcribed_text = transcription_result
        else:
            print(f"Gradio: Unexpected ASR result format: {type(transcription_result)}, value: {transcription_result}")
            transcribed_text = "ASR result format not recognized."
        transcribed_text = transcribed_text.strip()
        print(f"Gradio: Transcription ({ASR_MODEL_NAME}, CTC): {transcribed_text}")
        return transcribed_text if transcribed_text else "Transcription resulted in empty text."
    except Exception as e:
        print(f"Gradio: Error during {ASR_MODEL_NAME} transcription (AutoModel callable): {e}")
        import traceback
        traceback.print_exc()
        return f"Error during transcription ({ASR_MODEL_NAME}): {str(e)}"
# --- Gemini LLM and TTS helpers ---
def generate_gemini_response_gradio(text_input: str):
    if not gemini_model_instance_gradio:
        return "Error: Gemini LLM not configured or API key missing."
    asr_error_markers = (
        "No audio provided", "Transcription resulted in empty text",
        "Empty audio received", "ASR result format not recognized",
    )
    if (not isinstance(text_input, str) or not text_input.strip()
            or text_input.startswith("Error:")
            or any(marker in text_input for marker in asr_error_markers)):
        print(f"Gradio: Invalid input to Gemini: '{text_input}'. Skipping LLM response.")
        return "LLM (Gemini) skipped due to transcription issue or no input."
    try:
        print(f"Gradio: Sending to Gemini: '{text_input}'")
        full_prompt = f"User: {text_input}\nAssistant:"
        response = gemini_model_instance_gradio.generate_content(full_prompt)
        response_text = ""
        if response.candidates and response.candidates[0].content.parts:
            response_text = response.candidates[0].content.parts[0].text.strip()
        else:
            feedback_info = ""
            if hasattr(response, "prompt_feedback") and response.prompt_feedback:
                feedback_info = f" Feedback: {response.prompt_feedback}"
            print(f"Gradio: Gemini response did not contain expected content.{feedback_info}")
            response_text = f"I'm sorry, I couldn't generate a response for that (Gemini).{feedback_info}"
        print(f"Gradio: Gemini LLM Response: {response_text}")
        return response_text if response_text else "Gemini LLM generated an empty response."
    except Exception as e:
        print(f"Gradio: Error during Gemini LLM generation: {e}")
        import traceback
        traceback.print_exc()
        return f"Error during Gemini LLM generation: {str(e)}"
def synthesize_speech_gradio(text_input: str, description: str = "A clear, female voice speaking in English."):
    if tts_model_gradio is None or tts_tokenizer_gradio is None:
        return "Error: TTS model or its tokenizer not loaded."
    llm_error_markers = (
        "LLM skipped", "generated an empty response", "not configured",
        "ASR result format not recognized",
    )
    if (not isinstance(text_input, str) or not text_input.strip()
            or text_input.startswith("Error:")
            or any(marker in text_input for marker in llm_error_markers)):
        print(f"Gradio: Invalid input to TTS: '{text_input}'. Skipping synthesis.")
        return "TTS skipped due to LLM issue or no input."
    try:
        print(f"Gradio: Synthesizing speech for: '{text_input}'")
        # ParlerTTS conditions on two token streams: the voice description and the prompt text
        description_tokenized = tts_tokenizer_gradio(description, return_tensors="pt", padding=True, truncation=True, max_length=128)
        description_ids = description_tokenized.input_ids.to(DEVICE)
        description_attention_mask = description_tokenized.attention_mask.to(DEVICE)
        prompt_tokenized = tts_tokenizer_gradio(text_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        prompt_ids = prompt_tokenized.input_ids.to(DEVICE)
        if prompt_ids.shape[-1] == 0:  # tokenized prompt is empty
            print(f"Gradio: Tokenized prompt for TTS is empty. Text was: '{text_input}'. Skipping synthesis.")
            return "TTS skipped: Input text resulted in empty tokens."
        generation = tts_model_gradio.generate(
            input_ids=description_ids,
            attention_mask=description_attention_mask,
            prompt_input_ids=prompt_ids,
            do_sample=True, temperature=0.7, top_k=50, top_p=0.95,
        ).cpu().numpy().squeeze()
        sampling_rate = tts_model_gradio.config.sampling_rate
        print(f"Gradio: Speech synthesized. Array shape: {generation.shape}, Sample rate: {sampling_rate}")
        return (sampling_rate, generation)
    except Exception as e:
        print(f"Gradio: Error during speech synthesis: {e}")
        import traceback
        traceback.print_exc()
        if "You need to specify either `text` or `text_target`" in str(e):
            return "Error in TTS: Model requires 'text' or 'text_target'. Input might be too short or problematic."
        return f"Error during speech synthesis: {str(e)}"
# --- Gradio Interface Definition ---
load_all_resources_gradio()
def full_pipeline_gradio(audio_input):
    transcribed_text_output = transcribe_audio_gradio(audio_input)
    print(f"DEBUG full_pipeline_gradio - Step 1 (Transcription): '{transcribed_text_output}' (type: {type(transcribed_text_output)})")
    llm_response_text_output = generate_gemini_response_gradio(transcribed_text_output)
    print(f"DEBUG full_pipeline_gradio - Step 2 (LLM Response): '{llm_response_text_output}' (type: {type(llm_response_text_output)})")
    tts_synthesis_result = synthesize_speech_gradio(llm_response_text_output)
    final_audio_output = None
    if isinstance(tts_synthesis_result, tuple) and len(tts_synthesis_result) == 2 and isinstance(tts_synthesis_result[1], np.ndarray):
        final_audio_output = tts_synthesis_result
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Success): Audio tuple with shape {tts_synthesis_result[1].shape}")
    else:
        error_message_from_tts = str(tts_synthesis_result) if isinstance(tts_synthesis_result, str) else "TTS synthesis failed or returned unexpected type"
        print(f"DEBUG full_pipeline_gradio - Step 3 (TTS Failed/Non-audio): {error_message_from_tts}. Providing silent audio.")
        llm_text_was_valid = (llm_response_text_output
                              and not llm_response_text_output.startswith("Error:")
                              and "LLM skipped" not in llm_response_text_output
                              and "ASR result format not recognized" not in llm_response_text_output)
        if llm_text_was_valid:
            # Append the TTS error to otherwise-valid LLM text
            llm_response_text_output = f"{llm_response_text_output} | (TTS Problem: {error_message_from_tts})"
        else:
            # The LLM text already carries an error; just note the TTS issue too
            llm_response_text_output = f"{llm_response_text_output} (TTS also had an issue: {error_message_from_tts})"
        default_sample_rate = tts_model_gradio.config.sampling_rate if tts_model_gradio and hasattr(tts_model_gradio, "config") else TARGET_SAMPLE_RATE
        final_audio_output = (default_sample_rate, np.array([0.0], dtype=np.float32))
        print("DEBUG full_pipeline_gradio - Step 3 (TTS Fallback): Silent audio tuple")
    print(f"DEBUG full_pipeline_gradio - RETURNING: Transcription='{transcribed_text_output}', LLM_Text='{llm_response_text_output}', Audio_Type={type(final_audio_output)}")
    return transcribed_text_output, llm_response_text_output, final_audio_output
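# Hedged smoke test (an assumption, not part of the app): push one second of
# silence through the pipeline to verify the wiring without a microphone.
#
#   silent = (TARGET_SAMPLE_RATE, np.zeros(TARGET_SAMPLE_RATE, dtype=np.float32))
#   transcription, reply, audio = full_pipeline_gradio(silent)
#   print(transcription, reply, type(audio))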
with gr.Blocks(title="Conversational AI Demo") as demo:
    gr.Markdown("# Conversational AI Demo (STT -> Gemini LLM -> TTS)")
    with gr.Row():
        audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
        process_button = gr.Button("Process Audio")
    with gr.Accordion("Outputs", open=True):
        transcription_out = gr.Textbox(label="You Said (Transcription)", lines=2)
        llm_response_out = gr.Textbox(label="Gemini Assistant Says (Text)", lines=5)
        audio_out = gr.Audio(label="Assistant Says (Audio)")
    process_button.click(
        fn=full_pipeline_gradio,
        inputs=[audio_in],
        outputs=[transcription_out, llm_response_out, audio_out],
    )
    gr.Markdown("---")
    gr.Markdown("### How to Use:")
    gr.Markdown("1. Ensure your `GEMINI_API_KEY` environment variable is set.")
    gr.Markdown("2. Click into the 'Speak Here' box and record your audio.")
    gr.Markdown("3. Click the 'Process Audio' button.")
    gr.Markdown("4. View the transcription, Gemini's text response, and listen to the audio response.")
if __name__ == "__main__":
    demo.launch(share=False)
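    # Note (an assumption, not in the original launch call): for long-running
    # model inference, Gradio's request queue is often enabled first, e.g.
    #   demo.queue().launch(share=False)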