Michael Hu committed
Commit b591083 · 1 Parent(s): 6f92dbc

refactor(stt): replace whisper with faster-whisper for improved performance


Switch from the transformers-based Whisper implementation to faster-whisper for better speed and memory efficiency. The new implementation makes torch optional (it is used only to detect CUDA, with a CPU fallback when it is absent) and selects a compute type suited to the available hardware: float16 on GPU, int8 on CPU.
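In outline, the pattern this commit adopts looks like the following. This is a minimal sketch, assuming faster-whisper is installed (pip install faster-whisper); the model size, compute types, and beam size mirror the diff below, while "audio.wav" is a placeholder path:

    from faster_whisper import WhisperModel

    # torch is optional here: it is only used to detect CUDA
    try:
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
    except ImportError:
        device = "cpu"

    # float16 on GPU, int8 quantization on CPU, as in the diff below
    compute_type = "float16" if device == "cuda" else "int8"
    model = WhisperModel("large-v3", device=device, compute_type=compute_type)

    # transcribe() returns a lazy segment generator plus metadata
    segments, info = model.transcribe("audio.wav", beam_size=5,
                                      language="en", task="transcribe")
    text = " ".join(segment.text.strip() for segment in segments)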

Files changed (1)
utils/stt.py +38 -46
utils/stt.py CHANGED
@@ -11,10 +11,8 @@ from abc import ABC, abstractmethod
 
 logger = logging.getLogger(__name__)
 
-import torch
-from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+from faster_whisper import WhisperModel as FasterWhisperModel
 from pydub import AudioSegment
-import soundfile as sf
 
 class ASRModel(ABC):
     """Base class for ASR models"""
@@ -43,64 +41,58 @@ class ASRModel(ABC):
 
 
 class WhisperModel(ASRModel):
-    """Whisper ASR model implementation"""
+    """Faster Whisper ASR model implementation"""
 
     def __init__(self):
         self.model = None
-        self.processor = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        # Check for CUDA availability without torch dependency
+        try:
+            import torch
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        except ImportError:
+            # Fallback to CPU if torch is not available
+            self.device = "cpu"
+        self.compute_type = "float16" if self.device == "cuda" else "int8"
 
     def load_model(self):
-        """Load Whisper model"""
-        logger.info("Loading Whisper model")
+        """Load Faster Whisper model"""
+        logger.info("Loading Faster Whisper model")
         logger.info(f"Using device: {self.device}")
+        logger.info(f"Using compute type: {self.compute_type}")
 
-        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
-            "openai/whisper-large-v3",
-            torch_dtype=torch.float32,
-            low_cpu_mem_usage=True,
-            use_safetensors=True
-        ).to(self.device)
-
-        self.processor = AutoProcessor.from_pretrained("unsloth/whisper-large-v3")
-        logger.info("Whisper model loaded successfully")
+        # Use large-v3 model with appropriate compute type based on device
+        self.model = FasterWhisperModel(
+            "large-v3",
+            device=self.device,
+            compute_type=self.compute_type
+        )
+        logger.info("Faster Whisper model loaded successfully")
 
     def transcribe(self, audio_path):
-        """Transcribe audio using Whisper"""
-        if self.model is None or self.processor is None:
+        """Transcribe audio using Faster Whisper"""
+        if self.model is None:
             self.load_model()
 
         wav_path = self.preprocess_audio(audio_path)
 
-        # Processing
-        logger.info("Processing audio input")
-        logger.debug("Loading audio data")
-        audio_data, sample_rate = sf.read(wav_path)
-        audio_data = audio_data.astype(np.float32)
+        # Transcription with Faster Whisper
+        logger.info("Generating transcription with Faster Whisper")
+        segments, info = self.model.transcribe(
+            wav_path,
+            beam_size=5,
+            language="en",
+            task="transcribe"
+        )
 
-        # Increase chunk length and stride for longer transcriptions
-        inputs = self.processor(
-            audio_data,
-            sampling_rate=16000,
-            return_tensors="pt",
-            # Increase chunk length to handle longer segments
-            chunk_length_s=60,
-            stride_length_s=10
-        ).to(self.device)
-
-        # Transcription
-        logger.info("Generating transcription")
-        with torch.no_grad():
-            # Add max_length parameter to allow for longer outputs
-            outputs = self.model.generate(
-                **inputs,
-                language="en",
-                task="transcribe",
-                max_length=448,  # Explicitly set max output length
-                no_repeat_ngram_size=3  # Prevent repetition in output
-            )
+        logger.info(f"Detected language '{info.language}' with probability {info.language_probability}")
+
+        # Collect all segments into a single text
+        result_text = ""
+        for segment in segments:
+            result_text += segment.text + " "
+            logger.debug(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
 
-        result = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
+        result = result_text.strip()
         logger.info(f"Transcription completed successfully")
         return result
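One behavioral note on the new transcribe(): faster-whisper returns segments as a lazy generator, so the actual decoding happens while the for loop above iterates over it, not inside the model.transcribe() call itself. If eager results are ever needed (for example to time the decode, or to iterate twice), the generator can be materialized first. A sketch, using the same names as the diff:

    segments, info = self.model.transcribe(wav_path, beam_size=5,
                                           language="en", task="transcribe")
    segments = list(segments)  # decoding actually runs here, not in transcribe()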