Futuresony committed
Commit
eb87835
·
verified ·
1 Parent(s): 1f2c99e

Create app.py

Files changed (1)
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
+ import gradio as gr
+ import torch
+ import torchaudio
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ from huggingface_hub import InferenceClient
+ from ttsmms import download, TTS
+ from langdetect import detect
+ import os
+ import wave
+ import numpy as np
+
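+ # Pip dependencies implied by the imports above (a plausible requirements.txt
+ # for the Space; exact package pins are an assumption): gradio, torch,
+ # torchaudio, transformers, huggingface_hub, ttsmms, langdetect, numpy
+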
+ # === Step 1: Load ASR Model ===
+ asr_model_name = "Futuresony/Future-sw_ASR-24-02-2025"
+ processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
+ asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
+
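+ # The processor bundles the feature extractor (raw waveform -> input_values)
+ # and the CTC tokenizer (token ids -> text); from_pretrained returns the
+ # model already in eval mode, so no explicit .eval() call is needed.
+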
+ # === Step 2: Load Text Generation Model ===
+ client = InferenceClient("unsloth/gemma-3-1b-it")
+ def format_prompt(user_input):
+     return f"{user_input}"
+
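+ # NOTE: format_prompt currently passes the text through unchanged, so the
+ # instruction-tuned Gemma model receives a raw string with no chat template.
+ # A hedged alternative (assumption: Gemma-style turn markers) would be:
+ #     def format_prompt(user_input):
+ #         return (f"<start_of_turn>user\n{user_input}<end_of_turn>\n"
+ #                 f"<start_of_turn>model\n")
+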
+ # === Step 3: Load TTS Models ===
+ swahili_dir = download("swh", "./data/swahili")
+ english_dir = download("eng", "./data/english")
+ swahili_tts = TTS(swahili_dir)
+ english_tts = TTS(english_dir)
+
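+ # download() fetches the MMS TTS checkpoints for each language into ./data on
+ # first launch, so initial startup needs network access; later runs reuse the
+ # cached files.
+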
+ # === Step 4: Generate silent fallback audio ===
+ def create_silent_wav(filename="./error.wav", duration_sec=1.0, sample_rate=16000):
+     if not os.path.exists(filename):
+         silence = np.zeros(int(sample_rate * duration_sec), dtype=np.int16)
+         with wave.open(filename, 'w') as wf:
+             wf.setnchannels(1)  # mono
+             wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
+             wf.setframerate(sample_rate)
+             wf.writeframes(silence.tobytes())
+
+ create_silent_wav()  # Call once at startup
+
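+ # The silent file gives gr.Audio a valid path to play whenever TTS fails,
+ # instead of surfacing a broken file path to the browser.
+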
+ # === Step 5: Transcription Function ===
+ def transcribe(audio_file):
+     try:
+         speech_array, sample_rate = torchaudio.load(audio_file)
+         if speech_array.dim() > 1 and speech_array.size(0) > 1:
+             speech_array = speech_array.mean(dim=0, keepdim=True)  # mix stereo down to mono
+         resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+         speech_array = resampler(speech_array).squeeze().numpy()
+         input_values = processor(speech_array, sampling_rate=16000, return_tensors="pt").input_values
+         with torch.no_grad():
+             logits = asr_model(input_values).logits
+         predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC decoding
+         transcription = processor.batch_decode(predicted_ids)[0]
+         return transcription
+     except Exception as e:
+         print("ASR Error:", e)
+         return "[ASR Failed]"
+
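+ # argmax over the logits is plain greedy CTC decoding; if accuracy matters,
+ # a beam-search decoder with a language model (e.g. pyctcdecode) is a common
+ # upgrade, though this app does not use one.
+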
+ # === Step 6: Text Generation Function ===
+ def generate_text(prompt):
+     try:
+         formatted_prompt = format_prompt(prompt)
+         response = client.text_generation(formatted_prompt, max_new_tokens=250, temperature=0.7, top_p=0.95)
+         return response.strip()
+     except Exception as e:
+         print("Text Generation Error:", e)
+         return "[Text Generation Failed]"
+
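+ # InferenceClient.text_generation calls the hosted Hugging Face inference
+ # endpoint for the model, so no weights are loaded locally; cold starts and
+ # rate limits surface here as the "[Text Generation Failed]" fallback.
+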
+ # === Step 7: Text-to-Speech Function ===
+ def text_to_speech(text):
+     wav_path = "./output.wav"
+     try:
+         lang = detect(text)  # inside try: langdetect raises on empty/ambiguous text
+         if lang == "sw":
+             swahili_tts.synthesis(text, wav_path=wav_path)
+         else:
+             english_tts.synthesis(text, wav_path=wav_path)
+         return wav_path
+     except Exception as e:
+         print("TTS Error:", e)
+         return "./error.wav"  # Use fallback silent audio
+
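+ # langdetect is also non-deterministic by default; if stable language picks
+ # matter, one option (an assumption, not in the original) is to seed it once
+ # at import time:
+ #     from langdetect import DetectorFactory
+ #     DetectorFactory.seed = 0
+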
+ # === Step 8: Combined Logic ===
+ def process_audio(audio):
+     transcription = transcribe(audio)
+     generated_text = generate_text(transcription)
+     speech = text_to_speech(generated_text)
+     print(f"[DEBUG] Transcription: {transcription}")
+     print(f"[DEBUG] Generated Text: {generated_text}")
+     print(f"[DEBUG] TTS Output Path: {speech} (type={type(speech)})")
+     return transcription, generated_text, speech
+
+ # === Step 9: Gradio Interface ===
+ with gr.Blocks() as demo:
+     gr.Markdown("<p align='center' style='font-size: 20px;'>End-to-End ASR → Text Generation → TTS</p>")
+     gr.HTML("<center>Upload or record audio. The model will transcribe, generate a response, and read it out.</center>")
+
+     audio_input = gr.Audio(label="🎙️ Input Audio", type="filepath")
+     text_output = gr.Textbox(label="📝 Transcription")
+     generated_text_output = gr.Textbox(label="🤖 Generated Text")
+     audio_output = gr.Audio(label="🔊 Output Speech")
+     submit_btn = gr.Button("Submit")
+
+     submit_btn.click(
+         fn=process_audio,
+         inputs=audio_input,
+         outputs=[text_output, generated_text_output, audio_output]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()