Commit ffa317c0
Parent(s):
files upload

- .gitattributes +36 -0
- README.md +13 -0
- app.py +153 -0
- asr/CTC_model.py +110 -0
- asr/__init__.py +0 -0
- asr/run_asr.py +42 -0
- asr/vocab.json +1 -0
- nlu/run_nlu.py +85 -0
- requirements.txt +7 -0
- resources/audios/speech_massive_samples/ar_sa_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/de_de_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/es_es_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/fr_fr_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/hu_hu_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/ko_kr_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/nl_nl_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/pl_pl_sample_audio.wav.wav +0 -0
- resources/audios/speech_massive_samples/pt_pt_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/ru_ru_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/tr_tr_sample_audio.wav +0 -0
- resources/audios/speech_massive_samples/vi_vn_sample_audio.wav +0 -0
- resources/audios/utt_1264.wav +0 -0
- resources/audios/utt_14684.wav +0 -0
- resources/audios/utt_16032.wav +0 -0
- resources/audios/utt_2414.wav +0 -0
- resources/audios/utt_286.wav +0 -0
- resources/audios/utt_3060.wav +0 -0
- resources/audios/utt_5410.wav +0 -0
- resources/audios/utt_6162.wav +0 -0
- resources/audios/utt_9137.wav +0 -0
- resources/audios/utt_9912.wav +0 -0
- resources/logos/EU_flag.jpg +0 -0
- resources/logos/FBK_logo.png +0 -0
- resources/logos/NAVERLABS_2_BLACK.png +0 -0
- resources/logos/Utter_logo.png +0 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
nlu/model/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: French SLU DEMO Interspeech2024
emoji: 🚀
colorFrom: yellow
colorTo: blue
sdk: gradio
sdk_version: 4.41.0
app_file: app.py
pinned: false
license: cc-by-nc-sa-4.0
short_description: French SLU demo
---
app.py
ADDED
@@ -0,0 +1,153 @@
import gradio as gr
import numpy as np
import librosa

from asr.run_asr import run_asr_inference, load_asr_model
from nlu.run_nlu import run_nlu_inference, load_nlu_model

############### strings
mhubert_link = '[mHuBERT-147 model](https://huggingface.co/utter-project/mHuBERT-147)'
massive_link = '[Speech-MASSIVE dataset](https://huggingface.co/datasets/FBK-MT/Speech-MASSIVE)'
blog_post_link = '[blog post](TO DO TO DO)'
title = "# DEMO: French Spoken Language Understanding using mHuBERT-147 and Speech-MASSIVE"
description = [
    f"""
    **Interspeech 2024 DEMO.** Cascaded SLU using {mhubert_link} and {massive_link} components.
    """,
    """For more details on the implementation, check our blog post.""",
]

examples = [
    "resources/audios/utt_286.wav",
    "resources/audios/utt_2414.wav",
    "resources/audios/utt_16032.wav",
    "resources/audios/utt_3060.wav",
    "resources/audios/utt_1264.wav",
    "resources/audios/utt_9912.wav",
    "resources/audios/utt_14684.wav",
    "resources/audios/utt_5410.wav",
]
transcriptions = [
    "allume les lumières dans la cuisine",
    "je veux commander une pizza chez michael's pizza",
    "veuillez envoyer un e-mail à sally concernant la réunion de demain",
    "quelles sont les nouvelles de newsource",
    "mon réveil est-il réglé pour demain matin",
    "olly combien de temps dois-je faire bouillir les oeufs",
    "qui est le premier ministre de russie",
    "lis moi les derniers gros titres du new york times"
]
intents = [
    "iot_hue_lighton",
    "takeaway_order",
    "email_sendemail",
    "news_query",
    "alarm_query",
    "cooking_recipe",
    "qa_factoid",
    "news_query"
]
slots = [
    ["Other", "Other", "Other", "Other", "Other", "house_place"],
    ["Other", "Other", "Other", "Other", "food_type", "Other", "business_name", "business_name"],
    ['Other', 'Other', 'Other', 'Other', 'Other', 'Other', 'person', 'Other', 'Other', 'event_name', 'Other', 'date'],
    ['Other', 'Other', 'Other', 'Other', 'Other', 'media_type'],
    ['Other', 'Other', 'Other', 'Other', 'Other', 'Other', 'date', 'timeofday'],
    ['Other', 'Other', 'Other', 'Other', 'Other', 'Other', 'Other', 'cooking_type', 'Other', 'food_type'],
    ['Other', 'Other', 'Other', 'Other', 'Other', 'Other', 'place_name'],
    ['Other', 'Other', 'Other', 'Other', 'Other', 'Other', 'Other', 'media_type', 'media_type', 'media_type']
]


utter_ack_text = """This is an output of the European Project UTTER (Unified Transcription and Translation for Extended Reality) funded by the European Union’s Horizon Europe Research and Innovation programme under grant agreement number 101070631.
For more information please visit https://he-utter.eu/"""

ack_authors = """This demo was made by [Beomseok Lee](https://mt.fbk.eu/author/blee/) and [Marcely Zanon Boito](https://sites.google.com/view/mzboito/marcely-zanon-boito)."""

eu_logo = """<img src="https://huggingface.co/spaces/naver/French-SLU-DEMO-Interspeech2024/resolve/main/resources/logos/EU_flag.jpg" width="100" height="100">"""
utter_logo = """<a href="https://he-utter.eu/" target="_blank"><img src="https://huggingface.co/spaces/naver/French-SLU-DEMO-Interspeech2024/resolve/main/resources/logos/Utter_logo.png" width="50" height="50"></a>"""
nle_logo = """<a href="https://europe.naverlabs.com/" target="_blank"><img src="https://huggingface.co/spaces/naver/French-SLU-DEMO-Interspeech2024/resolve/main/resources/logos/NAVERLABS_2_BLACK.png" width="100" height="100"></a>"""
fbk_logo = """<a href="https://mt.fbk.eu/" target="_blank"><img src="https://huggingface.co/spaces/naver/French-SLU-DEMO-Interspeech2024/resolve/main/resources/logos/FBK_logo.png" width="100" height="100"></a>"""


table = f"""
| File | Transcription | Slots | Intent |
| ------------ | ------------------- | ---------- | -----------|
| {examples[0].split("/")[-1]} | {transcriptions[0]} | {slots[0]} | {intents[0]} |
| {examples[1].split("/")[-1]} | {transcriptions[1]} | {slots[1]} | {intents[1]} |
| {examples[2].split("/")[-1]} | {transcriptions[2]} | {slots[2]} | {intents[2]} |
| {examples[3].split("/")[-1]} | {transcriptions[3]} | {slots[3]} | {intents[3]} |
| {examples[4].split("/")[-1]} | {transcriptions[4]} | {slots[4]} | {intents[4]} |
| {examples[5].split("/")[-1]} | {transcriptions[5]} | {slots[5]} | {intents[5]} |
| {examples[6].split("/")[-1]} | {transcriptions[6]} | {slots[6]} | {intents[6]} |
| {examples[7].split("/")[-1]} | {transcriptions[7]} | {slots[7]} | {intents[7]} |
""".strip()

############### calls

def run_inference(audio_file):
    print(audio_file)
    # librosa.load returns (waveform, sampling_rate)
    audio_struct = librosa.load(audio_file, sr=16000)
    print(audio_struct)
    audio = {'sampling_rate': audio_struct[1], 'array': audio_struct[0]}  # .astype(np.float32)
    transcription = run_asr_inference(asr_model, processor, audio)
    print(transcription)
    structured_output = run_nlu_inference(nlu_model, tokenizer, transcription)

    return structured_output

############### app

asr_model, processor = load_asr_model()
nlu_model, tokenizer = load_nlu_model()

demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(title)
    for line in description:
        gr.Markdown(line)

    with gr.Row():
        waveform_options = gr.WaveformOptions(sample_rate=16000)

        audio_file = gr.Audio(
            label="Audio file",
            sources=['microphone', 'upload'],
            type="filepath",
            format='wav',
            waveform_options=waveform_options,
            show_download_button=False,
            show_share_button=False,
            max_length=20,
        )

        output = gr.HighlightedText(label="ASR result + NLU result")

        gr.Button("Run Inference", variant='primary').click(
            run_inference,
            concurrency_limit=2,
            inputs=audio_file,
            outputs=output,
        )

    with gr.Row():
        gr.Examples(label="Speech-MASSIVE test utterances:", inputs=audio_file, examples=examples)
        gr.Markdown(table)

    gr.Markdown("# Acknowledgments")
    gr.Markdown(utter_ack_text)
    gr.Markdown(ack_authors)

    with gr.Row():
        gr.Markdown(eu_logo)
        gr.Markdown(utter_logo)
        gr.Markdown(nle_logo)
        gr.Markdown(fbk_logo)

demo.queue()
demo.launch()
asr/CTC_model.py
ADDED
@@ -0,0 +1,110 @@
"""
Inference CTC class derived from HubertForCTC.

Author: Marcely Zanon Boito, 2024
"""


from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import HubertPreTrainedModel, HubertModel
from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput

class VanillaNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        simple NN with ReLU activation (no norm)
        """
        super().__init__()

        self.linear = nn.Linear(input_dim, output_dim)
        self.act_fn = nn.ReLU()

    def forward(self, hidden_states: torch.FloatTensor):
        hidden_states = self.linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        return hidden_states

class mHubertForCTC(HubertPreTrainedModel):
    def __init__(self, config, target_lang: Optional[str] = None):
        super().__init__(config)
        self.hubert = HubertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        output_hidden_size = config.hidden_size

        self.has_interface = config.add_interface

        # NN layers on top of the trainable stack
        if config.add_interface:
            self.interface = nn.ModuleList([VanillaNN(output_hidden_size, output_hidden_size) for i in range(config.num_interface_layers)])

        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = self.config.output_hidden_states

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        hidden_states = self.dropout(hidden_states)
        if self.has_interface:
            for layer in self.interface:
                hidden_states = layer(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None

        if labels is not None:
            if labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.ctc_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
asr/__init__.py
ADDED
File without changes
asr/run_asr.py
ADDED
@@ -0,0 +1,42 @@
"""
Inference main class.

Author: Marcely Zanon Boito, 2024
"""

from .CTC_model import mHubertForCTC

import torch
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers import HubertConfig

from datasets import load_dataset

fbk_test_id = 'FBK-MT/Speech-MASSIVE-test'
mhubert_id = 'utter-project/mHuBERT-147'

def load_asr_model():
    # Load the ASR model
    tokenizer = Wav2Vec2CTCTokenizer("asr/vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(mhubert_id)
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    config = HubertConfig.from_pretrained("naver/mHuBERT-147-ASR-fr")
    model = mHubertForCTC.from_pretrained("naver/mHuBERT-147-ASR-fr", config=config)
    model.eval()
    return model, processor

def run_asr_inference(model, processor, example):
    audio = processor(example["array"], sampling_rate=example["sampling_rate"]).input_values[0]
    input_values = torch.tensor(audio).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)

    prediction = processor.batch_decode(pred_ids)[0].replace('[CTC]', "")
    return prediction
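Note: a minimal usage sketch for the ASR stage above, assuming it is run from the repository root (so that asr/vocab.json resolves) with one of the bundled example files; the expected transcription is taken from the examples table in app.py.

import librosa

from asr.run_asr import load_asr_model, run_asr_inference

asr_model, processor = load_asr_model()

# librosa returns (waveform, sampling_rate); run_asr_inference expects a dict
array, sampling_rate = librosa.load("resources/audios/utt_286.wav", sr=16000)
transcription = run_asr_inference(asr_model, processor,
                                  {"array": array, "sampling_rate": sampling_rate})
print(transcription)  # expected roughly: "allume les lumières dans la cuisine"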
asr/vocab.json
ADDED
@@ -0,0 +1 @@
{"q": 0, "u": 1, "a": 2, "n": 3, "d": 4, "r": 6, "l": 7, "i": 8, "e": 9, "p": 10, "o": 11, "c": 12, "h": 13, "t": 14, "b": 15, "s": 16, "\u00e0": 17, "v": 18, "y": 19, "j": 20, "\u00e7": 21, "g": 22, "m": 23, "x": 24, "\u00f9": 25, "f": 26, "\u00e9": 27, "k": 28, "\u00ee": 29, "z": 30, "\u00f4": 31, "\u00ea": 32, "\u00e8": 33, "\u00fb": 34, "\u00e2": 35, "w": 36, "\u0153": 37, "\u00ef": 38, "\u014d": 39, "\u00eb": 40, "\u00f3": 41, "\u00fc": 42, "\u0144": 43, "\u016b": 44, "\u00e4": 45, "\u00e6": 46, "\u00ed": 47, "\u0107": 48, "\u00ec": 49, "\u00e5": 50, "\u00f8": 51, "\u00f6": 52, "\u0117": 53, "\u021b": 54, "\u0142": 55, "\u00f5": 56, "\u00e1": 57, "\u0131": 58, "\u015f": 59, "\u00fa": 60, "\u00e3": 61, "\u00f1": 62, "\u0161": 63, "\u0137": 64, "\u0101": 65, "\u00df": 66, "\u0103": 67, "\u0219": 68, "\u00f2": 69, "\u017e": 70, "\u0151": 71, "\u010d": 72, "\u044d": 73, "\u01b0": 74, "\u1edb": 75, "\u1ea5": 76, "\u00f0": 77, "\u02bd": 78, "\u02bc": 79, "\u5b87": 80, "\u6d25": 81, "\u4fdd": 82, "\u00ff": 83, "\u53b3": 84, "\u4e09": 85, "\u011f": 86, "\u015b": 87, "\u0119": 88, "\u02bf": 89, "\u0148": 90, "\u016f": 91, "1": 92, "9": 93, "8": 94, "\u016d": 95, "\u017c": 96, "\u1ea3": 97, "\u0171": 98, "0": 99, "\u0159": 100, "\u0111": 101, "\u03c4": 102, "\u1ed5": 103, "\u1eaf": 104, "\u017a": 105, "\u011b": 106, "\u0192": 107, "\u03b3": 108, "5": 109, "\u03c3": 110, "\u01ce": 111, "3": 112, "\u1ea1": 113, "\ua7a1": 114, "\u013c": 115, "\u7261": 116, "\u4e39": 117, "\u01d4": 118, "\u03b2": 119, "\u03b5": 120, "\u00fd": 121, "\u00fe": 122, "\u012b": 123, "2": 124, "\u0113": 125, "\u03c9": 126, "\u03b8": 127, "6": 128, "\u1ec5": 129, "\u1eb7": 130, "\u1eab": 131, "\u1e63": 132, "\u1fd6": 133, "\u03bf": 134, "\u03c2": 135, "\u03b1": 136, "\u03c0": 137, "\u03b4": 138, "\u03c6": 139, "4": 140, "\u1e25": 141, "\u03bb": 142, "\u03cd": 143, "\u03c1": 144, "\u03bc": 145, "\u1ecb": 146, "\u0169": 147, "\u3044": 148, "\u0294": 149, "\u05d3": 150, "\u05df": 151, "\u05d1": 152, "\u05e8": 153, "\u05d0": 154, "\u05d5": 155, "\u0165": 156, "\u05da": 157, "\u05d9": 158, "\u05d4": 159, "\u05de": 160, "\u04cc": 161, "\u1ec1": 162, "7": 163, "\u02bb": 164, "\u01eb": 165, "\u013e": 166, "\u1e0d": 167, "\u043a": 168, "\u0251": 169, "\u0105": 170, "\u03b9": 171, "|": 5, "[UNK]": 172, "[PAD]": 173, "[CTC]": 174}
nlu/run_nlu.py
ADDED
@@ -0,0 +1,85 @@
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer
)

from datasets import load_dataset
import torch

def load_nlu_model():
    config = AutoConfig.from_pretrained("Beomseok-LEE/NLU-Speech-MASSIVE-finetune")
    tokenizer = AutoTokenizer.from_pretrained("Beomseok-LEE/NLU-Speech-MASSIVE-finetune")
    model = AutoModelForSeq2SeqLM.from_pretrained("Beomseok-LEE/NLU-Speech-MASSIVE-finetune", config=config)

    return model, tokenizer

def run_nlu_inference(model, tokenizer, example):
    print(example)
    formatted_example = "Annotate: " + example
    input_values = tokenizer(formatted_example, max_length=128, padding=False, truncation=True, return_tensors="pt").input_ids

    with torch.no_grad():
        pred_ids = model.generate(input_values)

    prediction = tokenizer.decode(pred_ids[0], skip_special_tokens=True)
    print(prediction)

    splitted_pred = prediction.strip().split()

    slots_prediction = ''
    intent_prediction = ''

    if len(splitted_pred) >= 2:
        slots_prediction = splitted_pred[:-1]
        intent_prediction = splitted_pred[-1]
    if len(splitted_pred) == 1:
        slots_prediction = splitted_pred

    words = example.split(' ')

    title_1 = '[ASR output]\n'
    title_2 = '\n\n[NLU - slot filling]\n'
    title_3 = '\n\n[NLU - intent classification]\n'

    prefix_str_1 = title_1 + example + title_2
    prefix_str_2 = title_3

    structured_output = {
        'text': prefix_str_1 + example + prefix_str_2 + intent_prediction,
        'entities': []}

    structured_output['entities'].append({
        'entity': 'ASR output',
        'word': example,
        'start': len(title_1),
        'end': len(title_1) + len(example)
    })

    idx = len(prefix_str_1)

    for slot, word in zip(slots_prediction, words):
        _entity = slot
        _word = word
        _start = idx
        _end = idx + len(word)
        idx = _end + 1

        structured_output['entities'].append({
            'entity': _entity,
            'word': _word,
            'start': _start,
            'end': _end
        })

    idx = len(prefix_str_1 + example + prefix_str_2)

    if intent_prediction:
        structured_output['entities'].append({
            'entity': 'Classified Intent',
            'word': intent_prediction,
            'start': idx,
            'end': idx + len(intent_prediction)
        })

    return structured_output
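Note: a minimal sketch of driving the NLU stage on its own, using one of the reference transcriptions from app.py as input; the returned dict is exactly what gr.HighlightedText consumes in app.py (rendered text plus character-level entity spans).

from nlu.run_nlu import load_nlu_model, run_nlu_inference

nlu_model, tokenizer = load_nlu_model()

result = run_nlu_inference(nlu_model, tokenizer, "allume les lumières dans la cuisine")
print(result["text"])
for entity in result["entities"]:
    print(entity["entity"], entity["word"], entity["start"], entity["end"])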
requirements.txt
ADDED
@@ -0,0 +1,7 @@
numpy==1.26.3
torch==1.13.1
transformers==4.32.0
librosa==0.10.1
soundfile==0.12.1
torchaudio
accelerate
resources/audios/speech_massive_samples/ar_sa_sample_audio.wav
ADDED
Binary file (449 kB)

resources/audios/speech_massive_samples/de_de_sample_audio.wav
ADDED
Binary file (207 kB)

resources/audios/speech_massive_samples/es_es_sample_audio.wav
ADDED
Binary file (351 kB)

resources/audios/speech_massive_samples/fr_fr_sample_audio.wav
ADDED
Binary file (328 kB)

resources/audios/speech_massive_samples/hu_hu_sample_audio.wav
ADDED
Binary file (294 kB)

resources/audios/speech_massive_samples/ko_kr_sample_audio.wav
ADDED
Binary file (219 kB)

resources/audios/speech_massive_samples/nl_nl_sample_audio.wav
ADDED
Binary file (219 kB)

resources/audios/speech_massive_samples/pl_pl_sample_audio.wav.wav
ADDED
Binary file (423 kB)

resources/audios/speech_massive_samples/pt_pt_sample_audio.wav
ADDED
Binary file (392 kB)

resources/audios/speech_massive_samples/ru_ru_sample_audio.wav
ADDED
Binary file (305 kB)

resources/audios/speech_massive_samples/tr_tr_sample_audio.wav
ADDED
Binary file (248 kB)

resources/audios/speech_massive_samples/vi_vn_sample_audio.wav
ADDED
Binary file (340 kB)

resources/audios/utt_1264.wav
ADDED
Binary file (273 kB)

resources/audios/utt_14684.wav
ADDED
Binary file (206 kB)

resources/audios/utt_16032.wav
ADDED
Binary file (356 kB)

resources/audios/utt_2414.wav
ADDED
Binary file (329 kB)

resources/audios/utt_286.wav
ADDED
Binary file (198 kB)

resources/audios/utt_3060.wav
ADDED
Binary file (376 kB)

resources/audios/utt_5410.wav
ADDED
Binary file (309 kB)

resources/audios/utt_6162.wav
ADDED
Binary file (288 kB)

resources/audios/utt_9137.wav
ADDED
Binary file (363 kB)

resources/audios/utt_9912.wav
ADDED
Binary file (332 kB)

resources/logos/EU_flag.jpg
ADDED
Binary file

resources/logos/FBK_logo.png
ADDED
Binary file

resources/logos/NAVERLABS_2_BLACK.png
ADDED
Binary file

resources/logos/Utter_logo.png
ADDED
Binary file