Lorenzoncina committed on
Commit
b3db0b0
·
1 Parent(s): 85a5ac9

First version of FAMA models demo

Browse files
Files changed (2) hide show
  1. app.py +111 -0
  2. requirements.txt +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Description:
3
+ This script presents a Gradio demo for the ASR/ST FAMA models developed at FBK
4
+
5
+ Dependencies:
6
+ all the necessary dependencies are listed in requirements.txt
7
+
8
+ Usage:
9
+ The demo can be run locally by installing all necessary dependencies in a Python virtual env, or it can be run in a Hugging Face Space
10
+
11
+ Author: Lorenzo Concina
12
+ Date: 4/6/2025
13
+ """
14
+ import os
15
+ import torch
16
+ import librosa as lb
17
+ import gradio as gr
18
+ from transformers import AutoProcessor, pipeline
19
+ from datasets import load_dataset
20
+
21
def load_fama(model_id, output_lang):
    """Build the speech pipeline for a FAMA checkpoint and an output language.

    Args:
        model_id: Model identifier passed to ``AutoProcessor.from_pretrained``
            and to ``pipeline``.
        output_lang: Language code inserted into the ``<lang:..>`` tag that
            generation is forced to start with.

    Returns:
        The "automatic-speech-recognition" pipeline configured with beam
        search and the forced language tag.
    """
    processor = AutoProcessor.from_pretrained(model_id)

    # Run on the first GPU when one is visible, otherwise fall back to the CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Force the model to start with the language tag.
    lang_tag = f"<lang:{output_lang}>"
    lang_tag_id = processor.tokenizer.convert_tokens_to_ids(lang_tag)

    generate_kwargs = {
        "num_beams": 5,
        "no_repeat_ngram_size": 5,
        "forced_bos_token_id": lang_tag_id,
    }

    return pipeline(
        "automatic-speech-recognition",
        model=model_id,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        device=device,
        return_timestamps=False,
        generate_kwargs=generate_kwargs,
    )
43
+
44
def load_audio_file(audio_path):
    """Load an audio file with librosa as mono at 16 kHz and return the samples."""
    # librosa returns (samples, sample_rate); the rate is fixed by `sr`, so drop it.
    samples, _ = lb.load(audio_path, sr=16000, mono=True)
    return samples
47
+
48
def transcribe(audio, task_type, model_id, output_lang):
    """Run FAMA inference on one audio sample and return the produced text.

    Called by the Gradio interface.

    Args:
        audio: Path of the audio file to process; any other non-``None``
            value is forwarded to the pipeline unchanged.
        task_type: Task selected in the UI; not used by this function.
        model_id: Model identifier handed to ``load_fama``.
        output_lang: Language code handed to ``load_fama``.

    Returns:
        The ``"text"`` entry of the pipeline result.

    Raises:
        ValueError: If ``audio`` is ``None`` (nothing uploaded or recorded).
    """
    # Fail fast: without this check a model would be loaded and cached first,
    # only for the pipeline call to fail on the missing input.
    if audio is None:
        raise ValueError("No audio provided: upload or record an audio clip first.")

    # One pipeline per (model, language) pair, since the language tag is baked
    # into the pipeline's generation arguments.
    cache_key = (model_id, output_lang)
    if cache_key not in model_cache:
        model_cache[cache_key] = load_fama(model_id, output_lang)
    asr_pipe = model_cache[cache_key]

    if isinstance(audio, str) and os.path.isfile(audio):
        # Existing file on disk: decode it with librosa before inference.
        result = asr_pipe(load_audio_file(audio))
    else:
        # Anything else is passed to the pipeline as-is.
        result = asr_pipe(audio)
    return result["text"]
66
+
67
+ #available models
68
def update_model_options(task_type):
    """Return a Gradio update with the model choices offered for ``task_type``.

    "ST" offers only the two base models; any other task additionally offers
    the two ``-asr`` variants. The selected value is always reset to
    "FBK-MT/fama-small".
    """
    choices = ["FBK-MT/fama-small", "FBK-MT/fama-medium"]
    if task_type != "ST":
        choices += ["FBK-MT/fama-small-asr", "FBK-MT/fama-medium-asr"]
    return gr.update(choices=choices, value="FBK-MT/fama-small")
78
+
79
# Language options (languages supported by FAMA models).
# Used as the choices of the output-language dropdown in the UI.
language_choices = ["en", "it"]

# Cache of loaded pipelines, keyed by (model_id, output_lang), so that
# transcribe() does not reload a model on every request.
model_cache = {}
84
+
85
# Build and launch the Gradio UI only when the file is executed as a script.
# NOTE(review): indentation is not visible in this view; the nesting below
# (in particular which widgets sit inside gr.Row()) is assumed — confirm
# against the original file.
if __name__ == "__main__":

    with gr.Blocks() as iface:
        # Header text with a link to the FAMA model collection.
        gr.Markdown("""## FAMA ASR and ST\nSimple Automatic Speech Recognition and Speech Translation demo powered by FAMA models, developed at FBK. \
More informations about FAMA models can be found here: https://huggingface.co/collections/FBK-MT/fama-683425df3fb2b3171e0cdc9e""")

        with gr.Row():
            # type="filepath": the callback receives the audio as a path string.
            audio_input = gr.Audio(type="filepath", label="Upload or record audio")
            task_type_input = gr.Radio(choices=["ASR", "ST"], value="ASR", label="Select task type")

        # Initial model list matches the default task ("ASR"): all four models.
        model_input = gr.Radio(choices=[
            "FBK-MT/fama-small",
            "FBK-MT/fama-medium",
            "FBK-MT/fama-small-asr",
            "FBK-MT/fama-medium-asr"
        ], value="FBK-MT/fama-small", label="Select a FAMA model")

        lang_input = gr.Dropdown(choices=language_choices, value="it", label="Transcription language")

        output = gr.Textbox(label="Transcription")

        # Refresh the model list whenever the task selection changes.
        task_type_input.change(fn=update_model_options, inputs=task_type_input, outputs=model_input)

        transcribe_btn = gr.Button("Transcribe")
        # Input order must match the parameter order of transcribe().
        transcribe_btn.click(fn=transcribe, inputs=[audio_input, task_type_input, model_input, lang_input], outputs=output)

    iface.launch()
requirements.txt ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.12.7
4
+ aiosignal==1.3.2
5
+ annotated-types==0.7.0
6
+ anyio==4.9.0
7
+ attrs==25.3.0
8
+ audioread==3.0.1
9
+ certifi==2025.4.26
10
+ cffi==1.17.1
11
+ charset-normalizer==3.4.2
12
+ click==8.2.1
13
+ datasets==3.6.0
14
+ decorator==5.2.1
15
+ dill==0.3.8
16
+ fastapi==0.115.12
17
+ ffmpy==0.6.0
18
+ filelock==3.18.0
19
+ frozenlist==1.6.0
20
+ fsspec==2025.3.0
21
+ gradio==5.32.1
22
+ gradio_client==1.10.2
23
+ groovy==0.1.2
24
+ h11==0.16.0
25
+ hf-xet==1.1.2
26
+ httpcore==1.0.9
27
+ httpx==0.28.1
28
+ huggingface-hub==0.32.4
29
+ idna==3.10
30
+ Jinja2==3.1.6
31
+ joblib==1.5.1
32
+ lazy_loader==0.4
33
+ librosa==0.11.0
34
+ llvmlite==0.44.0
35
+ markdown-it-py==3.0.0
36
+ MarkupSafe==3.0.2
37
+ mdurl==0.1.2
38
+ mpmath==1.3.0
39
+ msgpack==1.1.0
40
+ multidict==6.4.4
41
+ multiprocess==0.70.16
42
+ networkx==3.4.2
43
+ numba==0.61.2
44
+ numpy==2.2.6
45
+ nvidia-cublas-cu12==12.6.4.1
46
+ nvidia-cuda-cupti-cu12==12.6.80
47
+ nvidia-cuda-nvrtc-cu12==12.6.77
48
+ nvidia-cuda-runtime-cu12==12.6.77
49
+ nvidia-cudnn-cu12==9.5.1.17
50
+ nvidia-cufft-cu12==11.3.0.4
51
+ nvidia-cufile-cu12==1.11.1.6
52
+ nvidia-curand-cu12==10.3.7.77
53
+ nvidia-cusolver-cu12==11.7.1.2
54
+ nvidia-cusparse-cu12==12.5.4.2
55
+ nvidia-cusparselt-cu12==0.6.3
56
+ nvidia-nccl-cu12==2.26.2
57
+ nvidia-nvjitlink-cu12==12.6.85
58
+ nvidia-nvtx-cu12==12.6.77
59
+ orjson==3.10.18
60
+ packaging==25.0
61
+ pandas==2.2.3
62
+ pillow==11.2.1
63
+ platformdirs==4.3.8
64
+ pooch==1.8.2
65
+ propcache==0.3.1
66
+ pyarrow==20.0.0
67
+ pycparser==2.22
68
+ pydantic==2.11.5
69
+ pydantic_core==2.33.2
70
+ pydub==0.25.1
71
+ Pygments==2.19.1
72
+ python-dateutil==2.9.0.post0
73
+ python-multipart==0.0.20
74
+ pytz==2025.2
75
+ PyYAML==6.0.2
76
+ regex==2024.11.6
77
+ requests==2.32.3
78
+ rich==14.0.0
79
+ ruff==0.11.12
80
+ safehttpx==0.1.6
81
+ safetensors==0.5.3
82
+ scikit-learn==1.6.1
83
+ scipy==1.15.3
84
+ semantic-version==2.10.0
85
+ sentencepiece==0.2.0
86
+ setuptools==80.9.0
87
+ shellingham==1.5.4
88
+ six==1.17.0
89
+ sniffio==1.3.1
90
+ soundfile==0.13.1
91
+ soxr==0.5.0.post1
92
+ starlette==0.46.2
93
+ sympy==1.14.0
94
+ threadpoolctl==3.6.0
95
+ tokenizers==0.21.1
96
+ tomlkit==0.13.2
97
+ torch==2.7.0
98
+ torchaudio==2.7.0
99
+ torchvision==0.22.0
100
+ tqdm==4.67.1
101
+ transformers==4.48.1
102
+ triton==3.3.0
103
+ typer==0.16.0
104
+ typing-inspection==0.4.1
105
+ typing_extensions==4.14.0
106
+ tzdata==2025.2
107
+ urllib3==2.4.0
108
+ uvicorn==0.34.3
109
+ websockets==15.0.1
110
+ xxhash==3.5.0
111
+ yarl==1.20.0