jc180 committed
Commit
fa8a5a0
1 Parent(s): 35ebfd2
Files changed (1): app.py +45 -3
app.py CHANGED
@@ -1,8 +1,50 @@
import gradio as gr
+ import torch
+ import torchaudio
+ from transformers import AutoFeatureExtractor, ASTForAudioClassification

- def greet(name):
-     return "Hello " + name + "!!"
+ model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
+ model = ASTForAudioClassification.from_pretrained(model_name)
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
+
+ device = torch.device("cpu")
+ model.to(device)
+
+ def classify_sound(file_path):
+     waveform, sr = torchaudio.load(file_path)
+
+     # Convert to mono
+     if waveform.shape[0] > 1:
+         waveform = waveform.mean(dim=0, keepdim=True)
+
+     # Resample to the 16 kHz rate the AST feature extractor expects
+     if sr != 16000:
+         waveform = torchaudio.functional.resample(waveform, sr, 16000)
+
+     inputs = feature_extractor(
+         waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt"
+     )
+
+     with torch.no_grad():
+         logits = model(**inputs).logits
+
+     probs = torch.softmax(logits, dim=-1)[0]
+     top5 = torch.topk(probs, k=5)
+
+     res = [
+         (model.config.id2label[idx.item()], round(prob.item(), 4))
+         for idx, prob in zip(top5.indices, top5.values)
+     ]
+     return dict(res)

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ demo = gr.Interface(
+     fn=classify_sound,
+     inputs=gr.Audio(source="upload", type="filepath"),
+     outputs=gr.Label(num_top_classes=5),
+     title="Audio Classification with AST",
+     description="Upload an audio clip (speech, music, ambient sound, etc.). Model: MIT AST fine-tuned on AudioSet (527 classes).",
+     live=False,
+ )
+
demo.launch()
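
The checkpoint's label space, as cited in the description string, can be double-checked with a minimal sketch like the following (assumes only `transformers` is installed; not part of app.py):

from transformers import ASTForAudioClassification

# Load the same checkpoint the Space uses.
model = ASTForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593"
)

# The classification head covers the full AudioSet ontology: 527 classes.
# The "10-10" in the model name is the frequency/time patch stride used
# during fine-tuning, not a class count.
print(len(model.config.id2label))                    # -> 527
print([model.config.id2label[i] for i in range(3)])  # first few label names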