jc180 commited on
Commit
0fed8b2
·
1 Parent(s): 070850a
Files changed (1) hide show
  1. app.py +21 -3
app.py CHANGED
@@ -12,6 +12,7 @@ model.to(device)
12
 
13
  def classify_sound(file_path):
14
  wv, sr = torchaudio.load(file_path)
 
15
 
16
  # Convert to mono
17
  if wv.shape[0] > 1:
@@ -27,16 +28,33 @@ def classify_sound(file_path):
27
  probs = torch.softmax(logits, dim=-1)[0]
28
  top5 = torch.topk(probs, k=5)
29
 
30
- res = [
31
  (model.config.id2label[idx.item()], round(prob.item(), 4))
32
  for idx, prob in zip(top5.indices, top5.values)
33
  ]
34
- return dict(res)
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  demo = gr.Interface(
37
  fn=classify_sound,
38
  inputs=gr.Audio(sources="upload", type="filepath"),
39
- outputs=gr.Label(num_top_classes=5),
 
 
 
 
 
40
  title="Audio Classification with AST",
41
  description="Upload an audio clip (speech, music, ambient sound, etc.). Model: MIT AST fine-tuned on AudioSet (10 classes).",
42
  live=False,
 
12
 
13
  def classify_sound(file_path):
14
  wv, sr = torchaudio.load(file_path)
15
+ original_shape = wv.shape
16
 
17
  # Convert to mono
18
  if wv.shape[0] > 1:
 
28
  probs = torch.softmax(logits, dim=-1)[0]
29
  top5 = torch.topk(probs, k=5)
30
 
31
+ top5_labels = [
32
  (model.config.id2label[idx.item()], round(prob.item(), 4))
33
  for idx, prob in zip(top5.indices, top5.values)
34
  ]
35
+
36
+ full_probs = {
37
+ model.config.id2label[i]: round(prob.item(), 4)
38
+ for i, prob in enumerate(probs)
39
+ }
40
+
41
+ return {
42
+ "Top 5 Predictions": dict(top5_labels),
43
+ "Sampling Rate": sr,
44
+ "Waveform Shape": str(original_shape),
45
+ "All Probabilities": full_probs
46
+ }
47
+
48
 
49
  demo = gr.Interface(
50
  fn=classify_sound,
51
  inputs=gr.Audio(sources="upload", type="filepath"),
52
+ outputs=[
53
+ gr.Label(label = "Top 5 Pred", num_top_classes=5),
54
+ gr.Textbox(label="Sample Rate"),
55
+ gr.Textbox(label="Waveform Shape"),
56
+ gr.JSON(label="All Class Probabilities")
57
+ ],
58
  title="Audio Classification with AST",
59
  description="Upload an audio clip (speech, music, ambient sound, etc.). Model: MIT AST fine-tuned on AudioSet (10 classes).",
60
  live=False,