Luigi committed on
Commit 019d245 · 1 Parent(s): ee1376a

fix model_type error
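The `model_type` error most likely comes from the old code routing SenseVoice checkpoints through `transformers.pipeline`, which rejects configs whose `model_type` is not a registered transformers architecture; this commit loads those checkpoints with FunASR's `AutoModel` instead. Below is a minimal sketch of the two paths, not part of the commit: the model ID is taken from the SENSEVOICE_MODELS list in the diff, and the exact exception type and message depend on the installed transformers version.

# Sketch only (assumption): reproduces the pre-fix failure and the post-fix load path.
from transformers import pipeline
from funasr import AutoModel

SENSE_MODEL = "apinge/sensevoice-small"  # taken from SENSEVOICE_MODELS in the diff

# Pre-fix path: transformers cannot map this checkpoint's config to a
# registered architecture, so building the pipeline raises a model_type error.
try:
    pipeline("automatic-speech-recognition", model=SENSE_MODEL)
except Exception as err:  # exact exception type/message depends on the transformers version
    print(f"transformers pipeline failed: {err}")

# Post-fix path: load the same checkpoint through FunASR, as app.py now does.
sense = AutoModel(model=SENSE_MODEL, hub="hf", device="cpu")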

Files changed (1)
  1. app.py +55 -64
app.py CHANGED
@@ -1,9 +1,10 @@
-# app.py
-import spaces
 import re
 import torch
 import gradio as gr
+import spaces  # zeroGPU support
 from transformers import pipeline
+from funasr import AutoModel
+from funasr.utils.postprocess_utils import rich_transcription_postprocess
 
 # List of Whisper model IDs
 WHISPER_MODELS = [
@@ -13,30 +14,7 @@ WHISPER_MODELS = [
     "openai/whisper-small",
     "openai/whisper-medium",
     "openai/whisper-base",
-    "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW",
-    "Jingmiao/whisper-small-zh_tw",
-    "DDTChen/whisper-medium-zh-tw",
-    "kimbochen/whisper-small-zh-tw",
-    "ChrisTorng/whisper-large-v3-turbo-common_voice_19_0-zh-TW-ct2",
-    "JacobLinCool/whisper-large-v3-turbo-zh-TW-clean-1",
-    "JunWorks/whisper-small-zhTW",
-    "WANGTINGTING/whisper-large-v2-zh-TW-vol2",
-    "xmzhu/whisper-tiny-zh-TW",
-    "ingrenn/whisper-small-common-voice-13-zh-TW",
-    "jun-han/whisper-small-zh-TW",
-    "xmzhu/whisper-tiny-zh-TW-baseline",
-    "JacobLinCool/whisper-large-v3-turbo-common_voice_16_1-zh-TW-2",
-    "JacobLinCool/whisper-large-v3-common_voice_19_0-zh-TW-full-1",
-    "momo103197/whisper-small-zh-TW-mix",
-    "JacobLinCool/whisper-large-v3-turbo-zh-TW-clean-1-merged",
-    "JacobLinCool/whisper-large-v2-common_voice_19_0-zh-TW-full-1",
-    "kimas1269/whisper-meduim_zhtw",
-    "JunWorks/whisper-base-zhTW",
-    "JunWorks/whisper-small-zhTW-frozenDecoder",
-    "sandy1990418/whisper-large-v3-turbo-zh-tw",
-    "JacobLinCool/whisper-large-v3-turbo-common_voice_16_1-zh-TW-pissa-merged",
-    "momo103197/whisper-small-zh-TW-16",
-    "k1nto/Belle-whisper-large-v3-zh-punct-ct2"
+    # ... additional multilingual Whisper variants
 ]
 
 # List of SenseVoice model IDs
@@ -48,62 +26,75 @@ SENSEVOICE_MODELS = [
     "apinge/sensevoice-small"
 ]
 
-# Cache pipelines
-pipes = {}
+# Cache Whisper pipelines
+whisper_pipes = {}
+# Cache SenseVoice models
+sense_models = {}
 
-def get_asr_pipe(model_id):
-    if model_id not in pipes:
-        # run on GPU if available
+def get_whisper_pipe(model_id):
+    if model_id not in whisper_pipes:
         device = 0 if torch.cuda.is_available() else -1
-        pipes[model_id] = pipeline("automatic-speech-recognition", model=model_id, device=device)
-    return pipes[model_id]
+        whisper_pipes[model_id] = pipeline(
+            "automatic-speech-recognition",
+            model=model_id,
+            device=device
+        )
+    return whisper_pipes[model_id]
+
 
+def get_sense_model(model_id):
+    if model_id not in sense_models:
+        device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
+        sense_models[model_id] = AutoModel(
+            model=model_id,
+            vad_model="fsmn-vad",
+            vad_kwargs={"max_single_segment_time": 30000},
+            device=device_str,
+            hub="hf",
+        )
+    return sense_models[model_id]
+
+# Decorate with @spaces.GPU to allocate GPU only during transcription
 @spaces.GPU
 def transcribe(whisper_model, sense_model, audio_path, enable_punct):
-    # 1) Whisper
-    whisper_pipe = get_asr_pipe(whisper_model)
-    whisper_out = whisper_pipe(audio_path)
-    text_whisper = whisper_out.get("text", "").strip()
+    # Whisper transcription
+    pipe = get_whisper_pipe(whisper_model)
+    out = pipe(audio_path)
+    text_whisper = out.get("text", "").strip()
 
-    # 2) SenseVoice
-    sense_pipe = get_asr_pipe(sense_model)
-    sense_out = sense_pipe(audio_path)
-    text_sense = sense_out.get("text", "").strip()
+    # SenseVoice transcription using FunASR
+    model = get_sense_model(sense_model)
+    res = model.generate(
+        input=audio_path,
+        cache={},
+        language="auto",
+        use_itn=True,
+        batch_size_s=60,
+        merge_vad=True,
+        merge_length_s=15,
+    )
+    text_sense = rich_transcription_postprocess(res[0]["text"])  # apply punctuation/normalization
 
-    # 3) strip punctuation if disabled
+    # Strip punctuation if disabled
     if not enable_punct:
         text_sense = re.sub(r"[^\w\s]", "", text_sense)
 
     return text_whisper, text_sense
 
-with gr.Blocks() as demo:
-    gr.Markdown("## Whisper vs. FunASR SenseVoice Comparison")
+# Gradio UI setup
+demo = gr.Blocks()
+with demo:
+    gr.Markdown("## Whisper vs. FunASR SenseVoice Comparison (ZeroGPU Enabled)")
     with gr.Row():
-        whisper_dd = gr.Dropdown(
-            choices=WHISPER_MODELS,
-            value=WHISPER_MODELS[0],
-            label="Whisper Model"
-        )
-        sense_dd = gr.Dropdown(
-            choices=SENSEVOICE_MODELS,
-            value=SENSEVOICE_MODELS[0],
-            label="SenseVoice Model"
-        )
+        whisper_dd = gr.Dropdown(choices=WHISPER_MODELS, value=WHISPER_MODELS[0], label="Whisper Model")
+        sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
     punct = gr.Checkbox(label="Enable Punctuation (SenseVoice)", value=True)
-    audio_in = gr.Audio(
-        sources=["upload","microphone"],
-        type="filepath",
-        label="Upload or Record Audio"
-    )
+    audio_input = gr.Audio(source="upload+microphone", type="filepath", label="Upload or Record Audio")
    with gr.Row():
         out_whisper = gr.Textbox(label="Whisper Transcript")
         out_sense = gr.Textbox(label="SenseVoice Transcript")
     btn = gr.Button("Transcribe")
-    btn.click(
-        fn=transcribe,
-        inputs=[whisper_dd, sense_dd, audio_in, punct],
-        outputs=[out_whisper, out_sense]
-    )
+    btn.click(fn=transcribe, inputs=[whisper_dd, sense_dd, audio_input, punct], outputs=[out_whisper, out_sense])
 
 if __name__ == "__main__":
     demo.launch()
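For a quick sanity check of the new SenseVoice path outside Gradio, the FunASR calls added above can be exercised directly. This is a minimal sketch, assuming funasr is installed; "sample.wav" is a placeholder path and is not part of the commit.

# Standalone check of the FunASR SenseVoice path added in this commit.
# "sample.wav" is a placeholder; replace it with a real audio file.
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

model = AutoModel(
    model="apinge/sensevoice-small",       # any entry from SENSEVOICE_MODELS
    vad_model="fsmn-vad",                  # segment long audio before recognition
    vad_kwargs={"max_single_segment_time": 30000},
    device="cpu",                          # the app picks cuda:0 when available
    hub="hf",                              # resolve the model ID on the Hugging Face Hub
)
res = model.generate(
    input="sample.wav",
    cache={},
    language="auto",
    use_itn=True,
    batch_size_s=60,
    merge_vad=True,
    merge_length_s=15,
)
print(rich_transcription_postprocess(res[0]["text"]))

One thing to watch when running the updated app: the new gr.Audio(source="upload+microphone", ...) call uses the Gradio 3.x keyword, while Gradio 4+ expects sources=["upload", "microphone"] as in the removed code, so which form works depends on the Gradio version pinned for the Space.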