jc180 commited on
Commit
c04325f
·
1 Parent(s): 0fed8b2
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -33,18 +33,18 @@ def classify_sound(file_path):
33
  for idx, prob in zip(top5.indices, top5.values)
34
  ]
35
 
36
- full_probs = {
37
- model.config.id2label[i]: round(prob.item(), 4)
38
- for i, prob in enumerate(probs)
39
- }
40
-
41
- return {
42
- "Top 5 Predictions": dict(top5_labels),
43
- "Sampling Rate": sr,
44
- "Waveform Shape": str(original_shape),
45
- "All Probabilities": full_probs
46
  }
47
 
 
 
 
 
 
 
48
 
49
  demo = gr.Interface(
50
  fn=classify_sound,
 
33
  for idx, prob in zip(top5.indices, top5.values)
34
  ]
35
 
36
+ top20 = torch.topk(probs, k=20)
37
+ top20_probs = {
38
+ model.config.id2label[idx.item()]: round(prob.item(), 4)
39
+ for idx, prob in zip(top20.indices, top20.values)
 
 
 
 
 
 
40
  }
41
 
42
+ return (
43
+ dict(top5_labels),
44
+ str(sr),
45
+ str(original_shape),
46
+ top20_probs
47
+ )
48
 
49
  demo = gr.Interface(
50
  fn=classify_sound,