xujinheng666 committed on
Commit
0eb093e
·
verified ·
1 Parent(s): 6589673

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -51
app.py CHANGED
@@ -3,69 +3,80 @@ import torch
3
  from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
  import torchaudio
5
  import os
 
6
 
7
def load_models():
    """Load the ASR, translation and rating models into ``st.session_state``.

    Populates four session-state entries: ``transcription_pipe``,
    ``translation_tokenizer``, ``translation_model`` and ``rating_pipe``.
    Returns nothing.
    """
    state = st.session_state
    translation_checkpoint = "botisan-ai/mt5-translate-yue-zh"

    # Speech recognition: placed on the GPU when torch can see one.
    state.transcription_pipe = pipeline(
        task="automatic-speech-recognition",
        model="alvanlii/whisper-small-cantonese",
        chunk_length_s=60,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    asr = state.transcription_pipe
    # Pin the decoder prompt so the model transcribes in "zh".
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language="zh", task="transcribe"
    )

    # Translation tokenizer and seq2seq model share one checkpoint.
    state.translation_tokenizer = AutoTokenizer.from_pretrained(translation_checkpoint)
    state.translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_checkpoint)

    # Sentiment classifier used as the quality rater.
    state.rating_pipe = pipeline(
        "sentiment-analysis",
        model="uer/roberta-base-finetuned-dianping-chinese",
    )
 
 
 
21
 
22
def transcribe_audio(audio_path):
    """Return the text the session's ASR pipeline produces for *audio_path*."""
    asr_output = st.session_state.transcription_pipe(audio_path)
    return asr_output["text"]
25
 
26
def translate_text(text):
    """Translate *text* with the session's seq2seq model and return the decoded string."""
    tok = st.session_state.translation_tokenizer
    seq2seq = st.session_state.translation_model

    encoded = tok(text, return_tensors="pt")
    # Beam search over the input ids; only the first returned sequence is used.
    generated = seq2seq.generate(encoded["input_ids"], max_length=1000, num_beams=5)
    best = generated[0]
    return tok.decode(best, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
def rate_quality(text):
    """Map the sentiment classifier's verdict on *text* to a quality rating.

    Returns "Poor", "Average" or "Good" for a negative, neutral or positive
    label respectively, and "Unknown" for any other label.
    """
    # Truncate at the tokenizer level: a long transcript would otherwise
    # exceed the classifier's maximum input length and raise.
    # NOTE(review): 512 is the usual limit for RoBERTa-base checkpoints - confirm.
    prediction = st.session_state.rating_pipe(text, truncation=True, max_length=512)[0]

    # Keep only the part of the label before any parenthesised suffix.
    sentiment = prediction["label"].split("(")[0].strip().lower()

    ratings = {"negative": "Poor", "neutral": "Average", "positive": "Good"}
    return ratings.get(sentiment, "Unknown")
39
 
40
def main():
    """Streamlit entry point: upload audio, then transcribe, translate and rate it.

    Models are loaded once per session (guarded by the ``transcription_pipe``
    session-state key). The uploaded bytes are written to a temporary file
    that is always removed afterwards, even when processing fails.
    """
    st.title("Audio Processing & Conversation Quality Rating")

    if "transcription_pipe" not in st.session_state:
        with st.spinner("Loading models..."):
            load_models()

    uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "m4a"])
    if uploaded_file is None:
        return

    file_path = "temp_audio.wav"
    with st.spinner("Processing audio..."):
        try:
            with open(file_path, "wb") as f:
                f.write(uploaded_file.read())

            transcript = transcribe_audio(file_path)
            translation = translate_text(transcript)
            rating = rate_quality(translation)
        finally:
            # Clean up on success and on failure; the file may not exist if
            # open() itself failed.
            if os.path.exists(file_path):
                os.remove(file_path)

    st.subheader("Transcription")
    st.write(transcript)

    st.subheader("Translation (Chinese)")
    st.write(translation)

    st.subheader("Conversation Quality Rating")
    st.write(rating)


if __name__ == "__main__":
    main()
 
3
  from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
  import torchaudio
5
  import os
6
+ import re
7
 
8
# Device setup: use the GPU when torch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper checkpoint used for transcription and the decoder language prompt.
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"


@st.cache_resource
def _load_transcription_pipe(model_name, lang, target_device):
    """Build the ASR pipeline once and reuse it across Streamlit reruns.

    Streamlit re-executes this script on every interaction; without the
    cache the model would be re-instantiated each time.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model=model_name,
        chunk_length_s=60,
        device=target_device,
    )
    # Pin the decoder prompt so the model transcribes in the given language.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=lang, task="transcribe"
    )
    return asr


pipe = _load_transcription_pipe(MODEL_NAME, language, device)
16
 
17
def transcribe_audio(audio_path):
    """Transcribe the audio file at *audio_path* and return the text.

    Audio of 60 seconds or less is passed to the pipeline directly. Longer
    audio is cut into 60-second windows that start every 50 seconds; each
    window is saved to a temporary WAV file, transcribed, and the texts are
    joined with single spaces. Because consecutive windows overlap by 10
    seconds, words near a boundary may appear twice in the result.
    """
    import tempfile  # local import: only needed on the long-audio path

    window_s = 60
    hop_s = 50

    waveform, sample_rate = torchaudio.load(audio_path)
    num_frames = waveform.shape[1]
    window = window_s * sample_rate
    hop = hop_s * sample_rate

    if num_frames <= window:
        return pipe(audio_path)["text"]

    results = []
    # The temporary directory removes every chunk file on exit, including
    # when the pipeline raises, and avoids name clashes in the working dir.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for start in range(0, num_frames, hop):
            # Slice in frames so the trailing fraction of a second is kept.
            end = min(start + window, num_frames)
            chunk_path = os.path.join(tmp_dir, f"temp_chunk_{start}.wav")
            torchaudio.save(chunk_path, waveform[:, start:end], sample_rate)
            results.append(pipe(chunk_path)["text"])
            if end >= num_frames:
                # Any later window would lie entirely inside this one.
                break
    return " ".join(results)
34
 
35
# Load translation model
@st.cache_resource
def _load_translation_model(checkpoint, target_device):
    """Load the translation tokenizer and model once, reused across reruns.

    Streamlit re-executes this script on every interaction; caching avoids
    re-instantiating the model each time. Returns ``(tokenizer, model)``.
    """
    tok = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(target_device)
    return tok, seq2seq


tokenizer, model = _load_translation_model("botisan-ai/mt5-translate-yue-zh", device)
38
+
39
# Split right after a sentence terminator, full-width (。！？) or ASCII (!?).
_SENTENCE_END_RE = re.compile(r'(?<=[。！？!?])')


def split_sentences(text):
    """Split *text* into sentences, keeping each terminator with its sentence.

    Surrounding whitespace is stripped from every sentence, and empty or
    whitespace-only fragments are dropped. Text without any terminator is
    returned as a single item; empty text yields an empty list.
    """
    return [s.strip() for s in _SENTENCE_END_RE.split(text) if s.strip()]
41
+
42
def translate(text):
    """Translate *text* one sentence at a time and join the results with spaces."""

    def _translate_sentence(sentence):
        # Encode on the model's device, beam-search, decode the first sequence.
        encoded = tokenizer(sentence, return_tensors="pt").to(device)
        generated = model.generate(encoded["input_ids"], max_length=1000, num_beams=5)
        return tokenizer.decode(generated[0], skip_special_tokens=True)

    return " ".join([_translate_sentence(piece) for piece in split_sentences(text)])
50
+
51
# Load sentiment analysis model
@st.cache_resource
def _load_rating_pipe(target_device):
    """Build the sentiment pipeline once and reuse it across Streamlit reruns.

    Placed on the same device as the other models for consistency.
    """
    return pipeline(
        "sentiment-analysis",
        model="uer/roberta-base-finetuned-dianping-chinese",
        device=target_device,
    )


rating_pipe = _load_rating_pipe(device)
53
 
54
def rate_quality(text):
    """Map the sentiment classifier's verdict on *text* to a quality rating.

    Returns "Poor", "Average" or "Good" for a negative, neutral or positive
    label respectively, and "Unknown" for any other label.
    """
    # Truncate at the tokenizer level: the translation of a long recording
    # would otherwise exceed the classifier's maximum input length and raise.
    # NOTE(review): 512 is the usual limit for RoBERTa-base checkpoints - confirm.
    prediction = rating_pipe(text, truncation=True, max_length=512)[0]

    # Keep only the part of the label before any parenthesised suffix.
    sentiment = prediction["label"].split("(")[0].strip().lower()

    ratings = {"negative": "Poor", "neutral": "Average", "positive": "Good"}
    return ratings.get(sentiment, "Unknown")
59
 
60
# Streamlit UI: upload a file, then show transcript, translation and rating.
st.title("Cantonese Audio Analysis")
st.write("Upload a Cantonese audio file to transcribe, translate, and rate the conversation quality.")

uploaded_file = st.file_uploader("Upload Audio File", type=["wav", "mp3", "flac"])

if uploaded_file is not None:
    # Play back with the upload's own MIME type; mp3 and flac are accepted
    # too, so "audio/wav" is only the fallback.
    st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

    # Keep the real extension so the audio loader sees a matching suffix.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    temp_audio_path = f"uploaded_audio{suffix}"

    try:
        with open(temp_audio_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        st.write("### Processing...")
        transcript = transcribe_audio(temp_audio_path)
        st.write("**Transcript:**", transcript)

        translated_text = translate(transcript)
        st.write("**Translation:**", translated_text)

        quality_rating = rate_quality(translated_text)
        st.write("**Quality Rating:**", quality_rating)
    finally:
        # Remove the temp file on success and on failure; it may not exist
        # if open() itself failed.
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)