# remote_whisper.py

import sys
import time
import logging
import os
import requests
import json
import base64
import numpy as np
import soundfile as sf
import io

# Import the necessary components from whisper_online.py
from libs.whisper_streaming.whisper_online import (
    ASRBase,
    OnlineASRProcessor,
    VACOnlineASRProcessor,
    add_shared_args,
    asr_factory as original_asr_factory,
    set_logging,
    create_tokenizer,
    load_audio,
    load_audio_chunk,
    OpenaiApiASR,
)
from model import dict_to_segment
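# Note: `dict_to_segment` comes from the local model.py, which is not shown
# here. The client below assumes it turns each JSON segment dict into an
# object shaped like faster-whisper's Segment (the attribute names are an
# assumption): segment.start, segment.end, segment.no_speech_prob, and
# segment.words, an iterable of word objects with .word, .start and .end.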
logger = logging.getLogger(__name__)


# Define the RemoteFasterWhisperASR class
class RemoteFasterWhisperASR(ASRBase):
    """Uses a remote FasterWhisper model via a WebSocket connection."""

    sep = ""  # Same as FasterWhisperASR

    def load_model(self, *args, **kwargs):
        import websocket

        self.ws = websocket.WebSocket()
        # Replace with the actual address of your ASR server
        server_address = "ws://localhost:8000/ws_transcribe_streaming"
        self.ws.connect(server_address)
        logger.info(f"Connected to remote ASR server at {server_address}")

    def transcribe(self, audio, init_prompt=""):
        # Convert the audio input to WAV bytes
        if isinstance(audio, str):
            # If audio is a filename, read the file
            with open(audio, 'rb') as f:
                audio_bytes = f.read()
        elif isinstance(audio, np.ndarray):
            # Write the audio samples to an in-memory buffer in WAV format
            audio_bytes_io = io.BytesIO()
            sf.write(audio_bytes_io, audio, samplerate=16000, format='WAV', subtype='PCM_16')
            audio_bytes = audio_bytes_io.getvalue()
        else:
            raise ValueError("Unsupported audio input type")

        # Encode to base64 and send as a JSON request over the WebSocket
        audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
        data = {
            'audio': audio_b64,
            'init_prompt': init_prompt,
        }
        self.ws.send(json.dumps(data))
        response = self.ws.recv()
        segments = json.loads(response)
        segments = [dict_to_segment(s) for s in segments]
        return segments
    def ts_words(self, segments):
        o = []
        for segment in segments:
            # Skip segments that the model flags as probable non-speech
            if segment.no_speech_prob > 0.9:
                continue
            for word in segment.words:
                # not stripping the spaces -- should not be merged with them!
                w = word.word
                t = (word.start, word.end, w)
                o.append(t)
        return o

    def segments_end_ts(self, res):
        return [s.end for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

# Update asr_factory to include RemoteFasterWhisperASR
def asr_factory(args, logfile=sys.stderr):
    """
    Creates and configures an ASR and Online ASR Processor instance based on
    the specified backend and arguments.
    """
    backend = args.backend
    if backend == "openai-api":
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=args.lan)
    elif backend == "remote-faster-whisper":
        t = time.time()
        logger.info(f"Initializing Remote Faster Whisper ASR for language '{args.lan}'...")
        asr = RemoteFasterWhisperASR(
            modelsize=args.model,
            lan=args.lan,
            cache_dir=args.model_cache_dir,
            model_dir=args.model_dir,
        )
        e = time.time()
        logger.info(f"Initialization done. It took {round(e - t, 2)} seconds.")
    else:
        # Use the original asr_factory for all other backends
        return original_asr_factory(args, logfile)

    # Apply common configurations
    if getattr(args, 'vad', False):  # Check whether the VAD argument is present and True
        logger.info("Setting VAD filter")
        asr.use_vad()

    language = args.lan
    if args.task == "translate":
        asr.set_translate_task()
        tgt_language = "en"  # Whisper translates into English
    else:
        tgt_language = language  # Whisper transcribes in this language

    # Create the tokenizer (only needed for sentence-based buffer trimming)
    if args.buffer_trimming == "sentence":
        tokenizer = create_tokenizer(tgt_language)
    else:
        tokenizer = None

    # Create the OnlineASRProcessor
    if args.vac:
        online = VACOnlineASRProcessor(
            args.min_chunk_size,
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
        )
    else:
        online = OnlineASRProcessor(
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
        )
    return asr, online
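
# Minimal usage sketch (for illustration; `args` is any namespace with the
# fields read above, and `audio_chunk` is a 16 kHz float32 numpy array):
#
#   asr, online = asr_factory(args)
#   online.insert_audio_chunk(audio_chunk)
#   beg, end, text = online.process_iter()  # committed text so far, if any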

# Main entry point: download a test audio file and stream it through the
# RemoteFasterWhisperASR backend, simulating real-time input.
def main():
    # Download the audio file if not already present
    AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"
    audio_file_path = "test_hebrew.wav"
    if not os.path.exists(audio_file_path):
        response = requests.get(AUDIO_FILE_URL)
        response.raise_for_status()  # Fail early rather than writing an error page to disk
        with open(audio_file_path, 'wb') as f:
            f.write(response.content)

    # Set up arguments
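    # The hard-coded Args class below mirrors the CLI flags that
    # whisper_online.py normally parses. As an alternative (untested sketch),
    # the imported add_shared_args helper could populate a real parser:
    #
    #   import argparse
    #   parser = argparse.ArgumentParser()
    #   parser.add_argument('audio_path', type=str)
    #   add_shared_args(parser)
    #   args = parser.parse_args()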
    class Args:
        def __init__(self):
            self.audio_path = audio_file_path
            self.lan = 'he'
            self.model = None  # Not used by RemoteFasterWhisperASR
            self.model_cache_dir = None
            self.model_dir = None
            self.backend = 'remote-faster-whisper'
            self.task = 'transcribe'
            self.vad = False
            self.vac = True  # Use voice-activity-controlled processing by default
            self.buffer_trimming = 'segment'
            self.buffer_trimming_sec = 15
            self.min_chunk_size = 1.0
            self.vac_chunk_size = 0.04
            self.start_at = 0.0
            self.offline = False
            self.comp_unaware = False
            self.log_level = 'DEBUG'

    args = Args()
    # Set up logging
    set_logging(args, logger)

    audio_path = args.audio_path
    SAMPLING_RATE = 16000
    duration = len(load_audio(audio_path)) / SAMPLING_RATE
    logger.info("Audio duration is: %2.2f seconds" % duration)

    asr, online = asr_factory(args, logfile=sys.stderr)
    if args.vac:
        min_chunk = args.vac_chunk_size
    else:
        min_chunk = args.min_chunk_size

    # Load the audio into the LRU cache before we start the timer
    a = load_audio_chunk(audio_path, 0, 1)
    # Warm up the ASR, because the very first transcribe call takes more time
    asr.transcribe(a)

    beg = args.start_at
    start = time.time() - beg
    def output_transcript(o, now=None):
        # Output format on stdout, e.g.:
        #   4186.3606 0 1720 Takhle to je
        # - The first number is the emission time from the beginning of
        #   processing, in milliseconds.
        # - The second and third numbers are the begin and end timestamps of
        #   the text segment, as estimated by the Whisper model.
        # - The rest of the line is the segment transcript.
        if now is None:
            now = time.time() - start
        if o[0] is not None:
            print("%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]), flush=True)
        # Otherwise there is no text, so nothing is printed.
    # Simulate real-time streaming: feed the file in chunks, pacing by wall clock
    end = 0
    while True:
        now = time.time() - start
        if now < end + min_chunk:
            time.sleep(min_chunk + end - now)
        end = time.time() - start
        a = load_audio_chunk(audio_path, beg, end)
        beg = end
        online.insert_audio_chunk(a)

        try:
            o = online.process_iter()
        except AssertionError as e:
            logger.error(f"Assertion error: {e}")
        else:
            output_transcript(o)

        now = time.time() - start
        logger.debug(f"## Last processed {end:.2f} s, now is {now:.2f}, latency is {now - end:.2f}")

        if end >= duration:
            break

    o = online.finish()
    output_transcript(o, now=None)


if __name__ == "__main__":
    main()